diff --git a/.gitignore b/.gitignore index 82f9275..2df0edd 100644 --- a/.gitignore +++ b/.gitignore @@ -160,3 +160,13 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +*.DS_Store* + +*outputs/ +*unsloth_compiled_cache/ +*.hatchet +*.nsys-rep +*trainer_output/ +*trace.json +*unsloth_trace*.json diff --git a/data_grapher.py b/data_grapher.py new file mode 100644 index 0000000..9a6e53a --- /dev/null +++ b/data_grapher.py @@ -0,0 +1,55 @@ +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt + +# 1) Load your data +df = pd.read_csv('all_results.csv') + +# 2) Compute mean of the four methods per script +metrics = ['baseline', 'proton', 'torch', 'nsys'] +grouped = ( + df + .groupby('script')[metrics] + .mean() + .reset_index() +) +print(grouped) + +# 3) Compute % difference from baseline +# (value - baseline) / baseline * 100 +for m in metrics: + # print(grouped[m]) + grouped[m] = (grouped[m] - grouped['baseline']) / grouped['baseline'] * 100 + print(m) + +# 4) Prepare for plotting +x_labels = grouped['script'] +data = grouped[metrics] +n_groups = len(x_labels) +n_metrics = len(metrics) +bar_width = 0.8 / n_metrics +x = np.arange(n_groups) + +# 5) Create the grouped bar chart +fig, ax = plt.subplots(figsize=(10, 5)) + +for i, m in enumerate(metrics): + ax.bar( + x + i * bar_width, + data[m], + width=bar_width, + label=m + ) + +# 6) Formatting +ax.set_xlabel('Script') +ax.set_ylabel('Percent Difference from Baseline (%)') +ax.set_title('Average Performance Difference from Baseline by Script') +ax.set_xticks(x + bar_width*(n_metrics-1)/2) +ax.set_xticklabels(x_labels, rotation=30, ha='right') +ax.axhline(0, color='black', linewidth=0.8, linestyle='--') # reference line at 0% +ax.legend(title='Implementation') + +plt.tight_layout() +plt.savefig('overhead_bench.png', dpi=300) +plt.show() diff --git a/end2end/graph.py b/end2end/graph.py new file mode 100644 index 0000000..315397a --- /dev/null +++ b/end2end/graph.py @@ -0,0 +1,109 @@ +import re +import csv +import sys +import argparse +import matplotlib.pyplot as plt +import numpy as np + +def parse_experiments(log_path): + # Patterns for experiment headers + first_hdr = re.compile(r"Timing\s+(?P\S+)\s") + prof_hdr = re.compile(r"^(?PNSYS|PROTON|TORCH)\s*$") + run_pattern = re.compile(r"Run\s*#?(?P\d+)[: of]*.*?(?P\d+\.\d+)\s*seconds") + + experiments = [] + current = None + current_model = None + + with open(log_path, 'r', encoding='utf-8') as f: + for raw_line in f: + line = raw_line.strip() + + # Detect first header: Timing NONE + m1 = first_hdr.search(line) + if m1: + current_model = m1.group('model') + current = {'model': current_model, 'profiler_type': 'NONE', 'runs': []} + experiments.append(current) + continue + + # Detect subsequent headers: NSYS, PROTON, TORCH + m2 = prof_hdr.search(line) + if m2 and current_model: + profiler = m2.group('profiler') + current = {'model': current_model, 'profiler_type': profiler, 'runs': []} + experiments.append(current) + continue + + # Parse run durations + if current: + run = run_pattern.search(line) + if run: + current['runs'].append({'run_number': int(run.group('number')), 'duration_s': float(run.group('duration'))}) + return experiments + + +def write_csv(experiments, out_path): + with open(out_path, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.DictWriter(csvfile, 
fieldnames=['model','profiler_type','run_number','duration_s']) + writer.writeheader() + for exp in experiments: + for run in exp['runs']: + writer.writerow({'model': exp['model'],'profiler_type': exp['profiler_type'],'run_number': run['run_number'],'duration_s': run['duration_s']}) + + +def plot_experiments(experiments): + # Group by model and profiler + models = sorted({exp['model'] for exp in experiments}) + # Calculate baseline (NONE) averages per model + base_avgs = {} + for model in models: + runs = next((exp['runs'] for exp in experiments if exp['model']==model and exp['profiler_type']=='NONE'), []) + base_avgs[model] = (sum(r['duration_s'] for r in runs)/len(runs)) if runs else 0.0 + + # Define desired profiler order, excluding NONE + all_profs = {exp['profiler_type'] for exp in experiments if exp['profiler_type']!='NONE'} + ordered = ['PROTON', 'NSYS', 'TORCH'] + profilers = [p for p in ordered if p in all_profs] + + # Calculate percentage difference from NONE + pct_diff = {prof: [] for prof in profilers} + for model in models: + base = base_avgs.get(model, 0.0) + for prof in profilers: + runs = next((exp['runs'] for exp in experiments if exp['model']==model and exp['profiler_type']==prof), []) + avg = (sum(r['duration_s'] for r in runs)/len(runs)) if runs else 0.0 + percent = ((avg - base) / base) * 100 if base > 0 else 0.0 + pct_diff[prof].append(percent) + + # Plot grouped bar chart + x = np.arange(len(models)) + width = 0.8 / len(profilers) + fig, ax = plt.subplots() + for i, prof in enumerate(profilers): + ax.bar(x + i * width, pct_diff[prof], width, label=prof) + + ax.set_xticks(x + width * (len(profilers) - 1) / 2) + ax.set_xticklabels(models, rotation=45, ha='right') + ax.set_ylabel('Percentage Increase in Training Time (%)') + ax.set_title('Percentage Overhead of Profilers for Unsloth Model Training') + ax.legend() + plt.tight_layout() + plt.savefig('unsloth_profiling.png') + # plt.show() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Extract and optionally plot experiment timings.') + parser.add_argument('log_file', help='Path to the log file') + parser.add_argument('out_csv', help='Output CSV file path') + parser.add_argument('--plot', action='store_true', help='Show grouped bar chart with percent difference') + args = parser.parse_args() + + experiments = parse_experiments(args.log_file) + write_csv(experiments, args.out_csv) + total_runs = sum(len(exp['runs']) for exp in experiments) + print(f"Extracted {total_runs} runs across {len(experiments)} experiments to {args.out_csv}") + + if args.plot: + plot_experiments(experiments) diff --git a/end2end/multimodal/all_timing.sh b/end2end/multimodal/all_timing.sh new file mode 100644 index 0000000..3125a7e --- /dev/null +++ b/end2end/multimodal/all_timing.sh @@ -0,0 +1,19 @@ + +#!/usr/bin/env bash +echo "NO PROFILER" +./profile-wrapper.sh 2 python training.py + +echo "--------------------------------------------" + +echo "NSYS" +./profile-wrapper.sh 2 nsys profile --trace=cuda --sample=none --cpuctxsw=none python training.py + +echo "--------------------------------------------" + +echo "PROTON" +./profile-wrapper.sh 2 proton training.py + +echo "--------------------------------------------" + +echo "TORCH" +./profile-wrapper.sh 2 python training.py --profile_torch diff --git a/end2end/multimodal/profile-wrapper.sh b/end2end/multimodal/profile-wrapper.sh new file mode 100755 index 0000000..e77b41d --- /dev/null +++ b/end2end/multimodal/profile-wrapper.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash 
+set -euo pipefail + +if (( $# < 2 )); then + echo "Usage: $0 " + echo "Example: $0 5 nsys python myscript.py arg1" + exit 1 +fi + +# Number of times to run +num_runs=$1 +shift + +# Array to hold elapsed times +declare -a times + +for ((i=1; i<=num_runs; i++)); do + echo + echo ">>> Run #$i of $num_runs" + + # capture output to a temp log + logfile=$(mktemp /tmp/profile-log.XXXXXX) + + # run the wrapped command + start_marker="" + # start time comes from script’s own START_PROFILE line; end time we record + "$@" 2>&1 | tee "$logfile" + exit_code=${PIPESTATUS[0]} + + # get end timestamp + end_ts=$(date +%s.%N) + + # extract the start timestamp from the first START_PROFILE line + start_line=$(grep -m1 '^START_PROFILE: ' "$logfile" || true) + if [[ -z "$start_line" ]]; then + echo "ERROR: no START_PROFILE found in run #$i" >&2 + rm -f "$logfile" + exit 1 + fi + start_ts=${start_line#START_PROFILE:\ } + + # compute elapsed + elapsed=$(echo "$end_ts - $start_ts" | bc) + + # store + times+=("$elapsed") + + # report this run + printf "Run %2d: %s seconds\n" "$i" "$elapsed" + + rm -f "$logfile" + + # if the wrapped command failed, stop early + if [[ $exit_code -ne 0 ]]; then + echo "Wrapped command exited with code $exit_code. Aborting." + exit $exit_code + fi +done + +# Summary: compute min, max, avg via bc +min=${times[0]} +max=${times[0]} +sum=0 +for t in "${times[@]}"; do + # compare floats: use bc + sleep(5) + is_less=$(echo "$t < $min" | bc) + (( is_less )) && min=$t + is_greater=$(echo "$t > $max" | bc) + (( is_greater )) && max=$t + sum=$(echo "$sum + $t" | bc) +done + +avg=$(echo "$sum / $num_runs" | bc -l) + +echo +echo "=== SUMMARY over $num_runs runs ===" +echo " Min elapsed : $min seconds" +echo " Max elapsed : $max seconds" +echo " Avg elapsed : $avg seconds" +echo "===================================" + +exit 0 + diff --git a/end2end/multimodal/training_multimodal.py b/end2end/multimodal/training_multimodal.py new file mode 100644 index 0000000..17be06e --- /dev/null +++ b/end2end/multimodal/training_multimodal.py @@ -0,0 +1,193 @@ +import argparse +import os + +from dataclasses import dataclass +import time + +import datasets +import torch +import transformers + +from datasets import Image as ImageFeature +from trl import SFTTrainer, SFTConfig +import triton.profiler as proton +from liger_kernel.transformers import monkey_patch + + +@dataclass +class CustomArguments: + model_name: str = "Qwen/Qwen2-VL-2B-Instruct" + dataset: str = "HuggingFaceM4/the_cauldron" + dataset_subset: str = "ai2d" + dataset_split: str = "train" + max_seq_length: int = 512 + dataset_text_field: str = "texts" + use_liger: bool = False + profile_torch: bool = False + + +def construct_model_and_processor(model_name: str, use_liger: bool) -> torch.nn.Module: + if "Qwen2-VL" in model_name: + from transformers import Qwen2VLForConditionalGeneration + + # These settings are used to reduce the memory footprint of the Qwen2-VL model, + # which supports training/inferences on images in their native resolution. Large + # images -> many visual tokens (a max of 16384) -> large memory consumption. + # If fine-tuning for a real-world application, consider these values carefully. 
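# Concretely (illustrative arithmetic based on the comments above): each visual
# token covers 4 patches of 14x14 pixels, i.e. 28*28 = 784 pixels per token,
# so the 256-token budget below corresponds to
# min_pixels = max_pixels = 256 * 28 * 28 = 200,704 pixels per image.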
+ min_visual_tokens_per_image = 256 + max_visual_tokens_per_image = 256 + + processor = transformers.AutoProcessor.from_pretrained( + model_name, + padding_side="left", + truncation_side="left", + min_pixels=min_visual_tokens_per_image * 28 * 28, # patch size is 14x14 + max_pixels=max_visual_tokens_per_image * 28 * 28, # 4 patches / token + cache_dir="/scratch/jlee436/liger/model", + ) + processor.tokenizer.pad_token = processor.tokenizer.eos_token + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + + if use_liger: + print("Applying Liger Kernel to Qwen2-VL model") + monkey_patch.apply_liger_kernel_to_qwen2_vl( + # These args can be used to override the default Liger settings + # cross_entropy=True, + # fused_linear_cross_entropy=False, + ) + + model = Qwen2VLForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=model_name, + use_cache=False, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + attn_implementation="sdpa", + cache_dir="/scratch/jlee436/liger/model" + ) + return model, processor, image_token_id + + raise NotImplementedError(f"Model {model_name} not supported") + + +def _validate_and_extract_the_cauldron(examples) -> dict[str, list]: + batch_texts = [] + batch_images = [] + for images, texts in zip(examples["images"], examples["texts"]): + if not images: + raise ValueError("No image found in example from the_cauldron dataset") + if len(images) > 1: + raise ValueError("Only one image per example is supported") + batch_texts.extend(texts) + batch_images.extend([images[0]] * len(texts)) + return {"texts": batch_texts, "images": batch_images} + + +def _format_for_convo(example, tokenizer): + # cauldron data is already in message format {"user": ..., "assistant": ...} + text = example["texts"] + messages = [ + { + "role": "user", + "content": [{"type": "image"}, {"type": "text", "text": text["user"]}], + }, + {"role": "assistant", "content": [{"type": "text", "text": text["assistant"]}]}, + ] + text = tokenizer.apply_chat_template(messages, tokenize=False) + # print({"texts": text}) + return {"text": text} + + +def train(): + parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments)) + training_args, custom_args = parser.parse_args_into_dataclasses() + training_args.max_steps = 100 + training_args.remove_unused_columns = False # required to not drop the image column + training_args.dataset_kwargs = {"skip_prepare_dataset": True, } + training_args.dataset_text_field = custom_args.dataset_text_field + training_args.max_seq_length = custom_args.max_seq_length + training_args.logging_dir = "/scratch/jlee436/liger/logs" + training_args.output_dir = "/scratch/jlee436/liger/checkpoints" + + model, processor, image_token_id = construct_model_and_processor(custom_args.model_name, custom_args.use_liger) + + dataset = ( + datasets.load_dataset( + custom_args.dataset, + custom_args.dataset_subset, + split=custom_args.dataset_split, + cache_dir="/scratch/jlee436/liger/data" + ) + .map( + _validate_and_extract_the_cauldron, + batched=True, + num_proc=min(os.cpu_count(), 16), + desc="Extracting text and images", + ) + .map( + _format_for_convo, + fn_kwargs={"tokenizer": processor.tokenizer}, + desc="Formatting for convo", + ) + .cast_column("images", ImageFeature()) + .train_test_split(test_size=0.1) + ) + + train_dataset = dataset["train"] + eval_dataset = dataset["test"] + + def collate_fn(examples): + """ + Taken directly from the TRL documentation with minor modifications: + 
https://huggingface.co/docs/trl/en/sft_trainer#a-custom-collator-for-processing-multi-modal-data + + Modifications: + 1. `apply_chat_template` is used to preprocess the texts before training begins (see above) + 2. `example["messages"]` -> `example["texts"]` to conform with the_cauldron dataset schema + 3. Ignoring image tokens in the loss computation + """ + # Get the texts and images + texts = [example["text"] for example in examples] + images = [example["images"] for example in examples] + + # Tokenize the texts and process the images + batch = processor(text=texts, images=images, return_tensors="pt", padding=True) + + # The labels are the input_ids, and we mask the padding tokens in the loss computation + labels = batch["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 + + # Ignore the image token index in the loss computation + labels[labels == image_token_id] = -100 + batch["labels"] = labels + + return batch + + trainer = SFTTrainer( + model=model, + args=training_args, + data_collator=collate_fn, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=processor.tokenizer, + # callbacks=[EfficiencyCallback()], + ) + if custom_args.profile_torch: + import torch.profiler + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], + record_shapes=True, + profile_memory=True, + ) as prof: + with torch.profiler.record_function("trainer.train"): + print(f"START_PROFILE: {time.time()}") + trainer.train() + print(f"END_PROFILE: {time.time()}") + else: + with proton.scope("trainer"): + print(f"START_PROFILE: {time.time()}") + trainer.train() + print(f"END_PROFILE: {time.time()}") + + +if __name__ == "__main__": + train() diff --git a/end2end/nano-gpt.sh b/end2end/nano-gpt.sh old mode 100644 new mode 100755 diff --git a/end2end/qwen2_overhead/all_timing.sh b/end2end/qwen2_overhead/all_timing.sh new file mode 100755 index 0000000..3125a7e --- /dev/null +++ b/end2end/qwen2_overhead/all_timing.sh @@ -0,0 +1,19 @@ + +#!/usr/bin/env bash +echo "NO PROFILER" +./profile-wrapper.sh 2 python training.py + +echo "--------------------------------------------" + +echo "NSYS" +./profile-wrapper.sh 2 nsys profile --trace=cuda --sample=none --cpuctxsw=none python training.py + +echo "--------------------------------------------" + +echo "PROTON" +./profile-wrapper.sh 2 proton training.py + +echo "--------------------------------------------" + +echo "TORCH" +./profile-wrapper.sh 2 python training.py --profile_torch diff --git a/end2end/qwen2_overhead/profile-wrapper.sh b/end2end/qwen2_overhead/profile-wrapper.sh new file mode 100755 index 0000000..19d47ad --- /dev/null +++ b/end2end/qwen2_overhead/profile-wrapper.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash +set -euo pipefail + +if (( $# < 2 )); then + echo "Usage: $0 " + echo "Example: $0 5 nsys python myscript.py arg1" + exit 1 +fi + +# Number of times to run +num_runs=$1 +shift + +# Array to hold elapsed times +declare -a times + +for ((i=1; i<=num_runs; i++)); do + echo + echo ">>> Run #$i of $num_runs" + + # capture output to a temp log + logfile=$(mktemp /tmp/profile-log.XXXXXX) + + # run the wrapped command + start_marker="" + # start time comes from script’s own START_PROFILE line; end time we record + "$@" 2>&1 | tee "$logfile" + exit_code=${PIPESTATUS[0]} + + # get end timestamp + end_ts=$(date +%s.%N) + + # extract the start timestamp from the first START_PROFILE line + start_line=$(grep -m1 '^START_PROFILE: ' "$logfile" || true) + if [[ -z "$start_line" ]]; then + echo "ERROR: no 
START_PROFILE found in run #$i" >&2 + rm -f "$logfile" + exit 1 + fi + start_ts=${start_line#START_PROFILE:\ } + + # compute elapsed + elapsed=$(echo "$end_ts - $start_ts" | bc) + + # store + times+=("$elapsed") + + # report this run + printf "Run %2d: %s seconds\n" "$i" "$elapsed" + + rm -f "$logfile" + + # if the wrapped command failed, stop early + if [[ $exit_code -ne 0 ]]; then + echo "Wrapped command exited with code $exit_code. Aborting." + exit $exit_code + fi +done + +# Summary: compute min, max, avg via bc +min=${times[0]} +max=${times[0]} +sum=0 +for t in "${times[@]}"; do + # compare floats: use bc + sleep 5 + is_less=$(echo "$t < $min" | bc) + (( is_less )) && min=$t + is_greater=$(echo "$t > $max" | bc) + (( is_greater )) && max=$t + sum=$(echo "$sum + $t" | bc) +done + +avg=$(echo "$sum / $num_runs" | bc -l) + +echo +echo "=== SUMMARY over $num_runs runs ===" +echo " Min elapsed : $min seconds" +echo " Max elapsed : $max seconds" +echo " Avg elapsed : $avg seconds" +echo "===================================" + +exit 0 + diff --git a/end2end/qwen2_overhead/training.py b/end2end/qwen2_overhead/training.py new file mode 100644 index 0000000..e588237 --- /dev/null +++ b/end2end/qwen2_overhead/training.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass + +import datasets +import torch +from torch import bfloat16 +import transformers + +# from callback import EfficiencyCallback +from trl import DataCollatorForCompletionOnlyLM +from trl import SFTTrainer +import time + +from liger_kernel.transformers import AutoLigerKernelForCausalLM +import triton.profiler as proton +import torch.profiler + +@dataclass +class CustomArguments: + model_name: str = "Qwen/Qwen2-1.5B" + dataset: str = "tatsu-lab/alpaca" + max_seq_length: int = 256 + use_liger: bool = False + profile_torch: bool = False + + +def formatting_prompts_func(example): + return example["text"] + + +def train(): + parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments)) + training_args, custom_args = parser.parse_args_into_dataclasses() + training_args.bf16 = True + training_args.bf16_full_eval = True + training_args.use_liger_kernel = custom_args.use_liger + training_args.max_seq_length = custom_args.max_seq_length + training_args.max_steps = 250 + + tokenizer = transformers.AutoTokenizer.from_pretrained( + custom_args.model_name, + padding_side="left", + truncation_side="left", + cache_dir="/scratch/jlee436/liger/model" + ) + tokenizer.pad_token = tokenizer.eos_token + + dataset = datasets.load_dataset(custom_args.dataset, cache_dir="/scratch/jlee436/liger/data")["train"].train_test_split(test_size=0.1) + train_dataset = dataset["train"] + eval_dataset = dataset["test"] + response_prompt = tokenizer.encode("### Response:\n", add_special_tokens=False) + collator = DataCollatorForCompletionOnlyLM( + tokenizer=tokenizer, + response_template=response_prompt, + pad_to_multiple_of=4, + ) + + if custom_args.use_liger: + model = AutoLigerKernelForCausalLM.from_pretrained( + custom_args.model_name, + trust_remote_code=True, + use_cache=False, + torch_dtype=bfloat16, + cache_dir="/scratch/jlee436/liger/model", + # These args will get passed to the appropriate apply_liger_kernel_to_* function + # to override the default settings + # cross_entropy=True, + # fused_linear_cross_entropy=False, + ) + else: + model = transformers.AutoModelForCausalLM.from_pretrained( + custom_args.model_name, + trust_remote_code=True, + use_cache=False, + torch_dtype=bfloat16, + cache_dir="/scratch/jlee436/liger/model", + 
) + + trainer = SFTTrainer( + model=model, + args=training_args, + data_collator=collator, + # max_seq_length=custom_args.max_seq_length, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + formatting_func=formatting_prompts_func, + # callbacks=[EfficiencyCallback()], + ) + if custom_args.profile_torch: + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], + record_shapes=True, + profile_memory=True, + ) as prof: + with torch.profiler.record_function("trainer.train"): + print(f"START_PROFILE: {time.time()}") + trainer.train() + print(f"END_PROFILE: {time.time()}") + prof.export_stacks(f"pt_trace.json") + else: + with proton.scope("trainer"): + print(f"START_PROFILE: {time.time()}") + trainer.train() + print(f"END_PROFILE: {time.time()}") + + +if __name__ == "__main__": + train() diff --git a/end2end/unsloth/all_timing.sh b/end2end/unsloth/all_timing.sh new file mode 100755 index 0000000..ca116e7 --- /dev/null +++ b/end2end/unsloth/all_timing.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +model_map=( + "phi-4" + "gemma-3-4b" + "llama-3.1-8b" + "llama-3.2-3b" + "qwen3-14b" + "mistral-7b" + # "llama-4-scout" +) + + +# loop over all models in model_map +for model in "${model_map[@]}"; do + echo "--------------------------------------------" + echo "Timing $model with NONE" + ./profile-wrapper.sh 3 python training.py --model $model --profiling none + sleep 5 + + echo "NSYS" + ./profile-wrapper.sh 3 nsys profile --trace=cuda --sample=none --cpuctxsw=none -o $model.nsys python training.py --model $model --profiling none + sleep 5 + echo "--------------------------------------------" + + echo "PROTON" + ./profile-wrapper.sh 3 proton -n $model training.py --model $model --profiling proton + sleep 5 + echo "--------------------------------------------" + + echo "TORCH" + ./profile-wrapper.sh 3 python training.py --model $model --profiling torch + sleep 5 + echo "--------------------------------------------" +done diff --git a/end2end/unsloth/gemma3/gemma_bench.py b/end2end/unsloth/gemma3/gemma_bench.py new file mode 100644 index 0000000..868e4d2 --- /dev/null +++ b/end2end/unsloth/gemma3/gemma_bench.py @@ -0,0 +1,101 @@ +import os +import torch +from unsloth import FastModel +from datasets import load_dataset +from unsloth.chat_templates import get_chat_template, standardize_data_formats, train_on_responses_only +from trl import SFTTrainer, SFTConfig +import triton.profiler as proton + +def main(profiling_mode): + # Initialize model and tokenizer + print("Initializing model and tokenizer...") + model, tokenizer = FastModel.from_pretrained( + model_name = "unsloth/gemma-3-4b-it", + max_seq_length = 2048, + load_in_4bit = False, + load_in_8bit = False, + full_finetuning = False, + cache_dir = "/scratch/jlee436/unsloth/model" + ) + + # Add LoRA adapters + print("Adding LoRA adapters...") + model = FastModel.get_peft_model( + model, + finetune_vision_layers = False, + finetune_language_layers = True, + finetune_attention_modules = True, + finetune_mlp_modules = True, + r = 8, + lora_alpha = 8, + lora_dropout = 0, + bias = "none", + random_state = 3407, + ) + + # Load and prepare dataset + print("Loading and preparing dataset...") + dataset = load_dataset("mlabonne/FineTome-100k", split="train", cache_dir="/scratch/jlee436/unsloth/data") + dataset = standardize_data_formats(dataset) + + # Apply chat template + print("Applying chat template...") + tokenizer = get_chat_template( + tokenizer, + chat_template = "gemma-3", + ) + + def apply_chat_template(examples): + texts = 
tokenizer.apply_chat_template(examples["conversations"]) + return {"text": texts} + + dataset = dataset.map(apply_chat_template, batched=True) + + # Initialize trainer + print("Initializing trainer...") + trainer = SFTTrainer( + model = model, + tokenizer = tokenizer, + train_dataset = dataset, + eval_dataset = None, + args = SFTConfig( + dataset_text_field = "text", + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + warmup_steps = 5, + max_steps = 30, # Change this for longer training + learning_rate = 2e-4, + logging_steps = 1, + optim = "adamw_8bit", + weight_decay = 0.01, + lr_scheduler_type = "linear", + seed = 3407, + report_to = "none", + ), + ) + trainer = train_on_responses_only( + trainer, + instruction_part = "user\n", + response_part = "model\n", + ) + # Train on responses only + print("Setting up response-only training...") + if profiling_mode == "proton": + session_id = proton.start(name="profile_name", context="shadow") + trainer_stats = trainer.train() + proton.finalize(session_id) + elif profiling_mode == "torch": + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + trainer_stats = trainer.train() + else: + trainer_stats = trainer.train() + + +if __name__ == "__main__": + # parse arguments to check profiling mode, which is a string "proton", "torch", or "none" + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiling", type=str, default="none") + args = parser.parse_args() + + main(args.profiling) \ No newline at end of file diff --git a/end2end/unsloth/gemma3/nsys_timing.sh b/end2end/unsloth/gemma3/nsys_timing.sh new file mode 100755 index 0000000..03981e5 --- /dev/null +++ b/end2end/unsloth/gemma3/nsys_timing.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Tell bash’s built-in time to output only the real time in seconds +export TIMEFORMAT='%R' + +echo "Running 5 trials; printing elapsed (real) time for each:" +for i in {1..5}; do + # Run the profiling command, suppress its stdout, capture the time output + # elapsed=$( { time python ./proton_test.py > /dev/null; } 2>&1 ) + elapsed=$( { time nsys profile --trace=cuda --sample=none --cpuctxsw=none python ./proton_test.py > /dev/null; } 2>&1 ) + echo "Trial $i: ${elapsed}s" +done diff --git a/end2end/unsloth/gemma3/proton_timing.sh b/end2end/unsloth/gemma3/proton_timing.sh new file mode 100755 index 0000000..89bd32e --- /dev/null +++ b/end2end/unsloth/gemma3/proton_timing.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Tell bash’s built-in time to output only the real time in seconds +export TIMEFORMAT='%R' + +echo "Running 5 trials; printing elapsed (real) time for each:" +for i in {1..5}; do + # Run the profiling command, suppress its stdout, capture the time output + elapsed=$( { time proton ./end2end/gemma3/proton_test.py > /dev/null; } 2>&1 ) + echo "Trial $i: ${elapsed}s" +done + diff --git a/end2end/unsloth/gemma3/pt_timing.sh b/end2end/unsloth/gemma3/pt_timing.sh new file mode 100755 index 0000000..61b6a19 --- /dev/null +++ b/end2end/unsloth/gemma3/pt_timing.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Tell bash’s built-in time to output only the real time in seconds +export TIMEFORMAT='%R' + +echo "Running 5 trials; printing elapsed (real) time for each:" +for i in {0..5}; do + # Run the profiling command, suppress its stdout, capture the time output + elapsed=$( { time python ./proton_test.py --profiler torch > /dev/null; } 2>&1 ) + echo "Trial $i: ${elapsed}s" +done + 
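All of the timing helpers above share one small contract: the wrapped training script prints a `START_PROFILE: <seconds-since-epoch>` line on stdout, and `profile-wrapper.sh` greps the first such line and subtracts that timestamp from its own `date +%s.%N` end timestamp, so start-up work done before the print (model and dataset loading, profiler initialization) is excluded from the measured interval. A minimal sketch of the script side of that contract, where `run_training` is a hypothetical stand-in for the real `trainer.train()` call:

```python
# Minimal sketch of the script-side timing contract used by profile-wrapper.sh.
# `run_training` is a hypothetical stand-in for trainer.train().
import time


def run_training() -> None:
    time.sleep(1.0)  # placeholder workload


if __name__ == "__main__":
    # profile-wrapper.sh greps the first line matching '^START_PROFILE: ' and
    # uses this timestamp as the start of the measured interval.
    print(f"START_PROFILE: {time.time()}", flush=True)
    run_training()
    # The wrapper records its own end timestamp; this line is informational.
    print(f"END_PROFILE: {time.time()}", flush=True)
```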
diff --git a/end2end/unsloth/listing.txt b/end2end/unsloth/listing.txt new file mode 100644 index 0000000..a80a950 --- /dev/null +++ b/end2end/unsloth/listing.txt @@ -0,0 +1,28 @@ +-rw-r--r-- 1 jlee436 users 926023 May 19 01:16 out.txt +-rw-r--r-- 1 jlee436 users 1114198954 May 19 01:15 unsloth_trace_mistral-7b.json +-rw-r--r-- 1 jlee436 users 171632 May 19 01:07 mistral-7b.hatchet +-rw-r--r-- 1 jlee436 users 1420212173 May 19 00:43 unsloth_trace_qwen3-14b.json +-rw-r--r-- 1 jlee436 users 190765 May 19 00:31 qwen3-14b.hatchet +-rw-r--r-- 1 jlee436 users 951312658 May 18 23:57 unsloth_trace_llama-3.2-3b.json +-rw-r--r-- 1 jlee436 users 174033 May 18 23:50 llama-3.2-3b.hatchet +-rw-r--r-- 1 jlee436 users 1105145854 May 18 23:29 unsloth_trace_llama-3.1-8b.json +-rw-r--r-- 1 jlee436 users 164490 May 18 23:20 llama-3.1-8b.hatchet +-rw-r--r-- 1 jlee436 users 1348787048 May 18 22:55 unsloth_trace_gemma-3-4b.json +-rw-r--r-- 1 jlee436 users 271858 May 18 22:44 gemma-3-4b.hatchet +-rw-r--r-- 1 jlee436 users 1389554811 May 18 22:16 unsloth_trace_phi-4.json +-rw-r--r-- 1 jlee436 users 196665 May 18 22:04 phi-4.hatchet +-rw-rw-r-- 1 jlee436 users 59051527 May 18 18:28 mistral-7b.nsys.nsys-rep +-rw-rw-r-- 1 jlee436 users 76175100 May 18 17:46 qwen3-14b.nsys.nsys-rep +-rw-rw-r-- 1 jlee436 users 50375612 May 18 17:12 llama-3.2-3b.nsys.nsys-rep +-rw-rw-r-- 1 jlee436 users 58898031 May 18 16:39 llama-3.1-8b.nsys.nsys-rep +-rw-rw-r-- 1 jlee436 users 70990128 May 18 16:01 gemma-3-4b.nsys.nsys-rep +-rw-rw-r-- 1 jlee436 users 74228119 May 18 15:20 phi-4.nsys.nsys-rep +-rwxr-xr-x 1 jlee436 users 992 May 18 15:06 all_timing.sh +-rw-r--r-- 1 jlee436 users 4903 May 18 13:58 training.py +-rwxr-xr-x 1 jlee436 users 1885 May 18 05:23 profile-wrapper.sh +drwxr-xr-x 2 jlee436 users 4096 May 18 05:14 unsloth_compiled_cache +drwxr-xr-x 2 jlee436 users 4096 May 18 04:48 unsloth_training_checkpoints +drwxr-xr-x 2 jlee436 users 4096 May 18 03:37 phi-4 +drwxr-xr-x 2 jlee436 users 4096 May 18 03:33 qwen3 +drwxr-xr-x 2 jlee436 users 4096 May 18 03:26 gemma3 +drwxr-xr-x 2 jlee436 users 4096 May 12 04:20 llama3_1_4b diff --git a/end2end/unsloth/llama3_1_4b/Gemma3_(4B).ipynb b/end2end/unsloth/llama3_1_4b/Gemma3_(4B).ipynb new file mode 100644 index 0000000..820c1fd --- /dev/null +++ b/end2end/unsloth/llama3_1_4b/Gemma3_(4B).ipynb @@ -0,0 +1,7431 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "NmbKJxjRjeWg" + }, + "source": [ + "To run this, press \"*Runtime*\" and press \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n", + "
\n", + "\n", + "\n", + " Join Discord if you need help + ⭐ Star us on Github ⭐\n", + "
\n", + "\n", + "To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://docs.unsloth.ai/get-started/installing-+-updating).\n", + "\n", + "You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-2D281t8jeWg" + }, + "source": [ + "### News" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SSY2XE1xjeWg" + }, + "source": [ + "**Read our [Gemma 3 blog](https://unsloth.ai/blog/gemma3) for what's new in Unsloth and our [Reasoning blog](https://unsloth.ai/blog/r1-reasoning) on how to train reasoning models.**\n", + "\n", + "Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ySsPI_X1jeWh" + }, + "source": [ + "### Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D_L_j6gejeWi" + }, + "outputs": [], + "source": [ + "%%capture\n", + "import os\n", + "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", + " !pip install unsloth vllm\n", + "else:\n", + " # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n", + " !pip install --no-deps unsloth vllm\n", + "# Install latest Hugging Face for Gemma-3!\n", + "!pip install --no-deps git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tZ3ezvsPjeWi" + }, + "outputs": [], + "source": [ + "#@title Colab Extra Install { display-mode: \"form\" }\n", + "%%capture\n", + "import os\n", + "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", + " !pip install unsloth vllm\n", + "else:\n", + " !pip install --no-deps unsloth vllm\n", + " # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n", + " # Skip restarting message in Colab\n", + " import sys, re, requests; modules = list(sys.modules.keys())\n", + " for x in modules: sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n", + " !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft \"trl==0.15.2\" triton cut_cross_entropy unsloth_zoo\n", + " !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer\n", + "\n", + " # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n", + " f = requests.get(\"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\").content\n", + " with open(\"vllm_requirements.txt\", \"wb\") as file:\n", + " file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n", + " !pip install -r vllm_requirements.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TGMWlrRdzwgf" + }, + "source": [ + "### Unsloth\n", + "\n", + "`FastModel` supports loading nearly any model now! This includes Vision and Text models!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 533, + "referenced_widgets": [ + "33815b4c0485402e838127b32ad14a15", + "81c3ee8d97d543ae83db3d2748ecd7cd", + "70f9bccb201a4ea79c1827057cd746f6", + "b32538030b8649e7851bfd58bef2786c", + "0abe1a2b12e54bb5a07db1a8f3a77738", + "e12bca7d49b149b5b1124c28a669db98", + "fab7a02d350946ae9563a05cfd04e22b", + "9be36a19ba86494389904b9fffdc2e48", + "899bf9a4dc594e178d0b95e3cbe08018", + "2fa2784330b2439ab884ce0966037961", + "51649c42cf2145f1b1e90c3436805350", + "eae020595f574192ad9d132853bbf6ec", + "17eaf723882e4efea38119978166fc75", + "409488926d2242c5a8e7b3d5b79c59db", + "26b61507603d453c8c24af24d301bdb9", + "c6c51a350a0b420aab957e3570098e18", + "cebd4fbf1fcf4ab2b65ecf539eda5a1e", + "f14ea72beac74152af5f970e634769ca", + "039f461e15214bd697501219bd9cbbd9", + "9f9d2f43fb5e47df883feb3126fe52e9", + "9e20d704e64a4aaabe5e495c468d9670", + "ae4e4537ed1d4df4b1677700d190a2d2", + "50881898da2f4b35a288ac9befe5024e", + "599375adc5f841d1864166e4d1bd617a", + "3bfc7d7dd81f49a59fc8ad0d6fff858b", + "a0c0025c82394e7fbc6d4cc9a9e9f72f", + "5f0a37b9edc74cbd822e2e71c6c8a956", + "a15e4524521b42108f49dda23ed56023", + "a4eaae1b30d442208257c1870d549738", + "84c68f2059f247628b672b1079130e9b", + "a88fccae11664baa82418c88639f3521", + "1ecb60e1d5934ea19a0d2c29aa01158d", + "f920f6cb263c45d59b851b7c6b631cb5", + "a37b4e454c6743a895f159f963366bc8", + "18fce8679d7d4961b88bb2162e7aa9eb", + "fd60bb21ba2e474cb4a6020bd302835e", + "0c3e3fbf02d84114906e939fcca108b5", + "ab71391446d8482c834429a937f7bb96", + "f9439c3c9b3b4c4a84ed67aa0601a530", + "a89575b4c58348ac98566c22ab7e4118", + "697a187a07204fdfb8556cf5c5028c6b", + "afdb7dfdc17548b39daea1f39d54b45c", + "c4a69698321d435c94211a9dee913c45", + "36c8135711194884be0e03cb5d3ae7e5", + "69a9d565de6a4d54b8f3989fa8b11941", + "090b146b8c6f4fa1ab0261d2a61be9de", + "d2df1617a4094dd295c49a0a72b269c3", + "54d963afc22a46ae93a6ca4bfaa18cf8", + "0b1768a5b5be4a4d9e14ed9e168044b4", + "708d05af00a64d19bb11dc839a5e68db", + "3d20e28d74e549c0a43686b214eebd87", + "dfac71c6372c46bfad46308bdb04480b", + "df50d8daa49a4955914704edb89baf61", + "52f091f99f0442f8bd14a50b7c870c1e", + "c859133fca324effb73ebc3520e746b6", + "f89c08592a25432497bb312f58a13c5c", + "c307400fb17e49ea9d835822e6e22633", + "57ed5097e05f4d92a5c492826f989123", + "db7e622bbd0f4357b5687a7c09c1f6fd", + "8864799f440c440c8ff8c0696e64215d", + "353fa47fb98c4070a150edec64503eaa", + "40bf6f1ffbc5479082fdd7ab153ea974", + "b5a06422fcac41eb97d2de95f14b1806", + "efa41d07d0fa4adda8025fe9490ed850", + "457c60d6a15d4314ba25d370be956a60", + "bc6f29c9a1e14ce8be374867b8be86ac", + "e003ee5cf1804ce3928b544b3fa7ba77", + "dd6e4f9b4c6d4260a62b920a6812fd07", + "4466f20e614a4cdabe1704859c2f1034", + "8949c35d68a043c5b1384774bb07b2ea", + "fca09f95775047efa9d481173f1ba261", + "879c6a0498e54e5c87145c7f7d32de7e", + "740d351b7de241a6acabf6c2853585b6", + "be2b1fb954444be089045a378673a958", + "9520485cc7c24c8584c3838717655012", + "94d9900bc8934de688e95e15c9d0c9bb", + "b469cdf580404f498feb062e4dbad10b", + "5975cf24b18e4082bd80e3f177e0ec15", + "b5b8482ef7c44e12a83795e7337521c2", + "ea9045a5c4504a5e96e6a7b13767fe4e", + "0cba80b626574c11a44c6ce09b5d6e80", + "f344c1ab154b4abdb84b2c221b6162a1", + "0881c055108340f7ab4b840ac1545cbb", + "77907c3444174858bbdc548dee8d0d37", + "a92be3fb752148d887e20afe300f9371", + "aa36bff36ae5448b892afea071fc1f1d", + "12d3049cca4a46c08cf5cdfcd5225248", + "6a9baf0a739c4790baf99b0ccadd6873", + 
"86c6d49a55b3477bbccc275dcb55fb52", + "77c5f8b431ba4c08b8f4d9d8f9fafc16", + "19c3ef35452d406cb18b72e38b631ee7", + "063db74d47814f95b560bd3bab11b55f", + "5c01ab4767104c0c96c42858317f8877", + "73a283e64f324c27b38216e683050b92", + "81bddfaa180d4875b5cdb5cc4ae45dab", + "77e843913b6e439ead9cf42725eedf3d", + "4f7e8b6b71484cce80ae9cdf0c481825", + "3530e2b431c041c6aeeaca4808ba0424", + "cc14e51320f34274ae12aa28b04183b7", + "f373f2c24f3b413aaa9fe1ccfb9c1eab", + "360142dac8c54a5eb902078ec42abb65", + "a4a64136f6fc48c799abf1701725534e", + "4c5e29de7224428bb87de46854ea915a", + "a3ce4f38be9a456c81146fba440c8e3f", + "70fdd31291b04dd68d66cf31c03d23ff", + "6feaf338d39440e78221649ac84af4a6", + "c9fbf40fa3dd4ba1b363802dc88764da", + "ae7f1fd06ddc4881934690675891855c", + "5e9ba3247edc4fafa7687338424ddccb", + "97d7e3420e24436cb351b1e9679ff8b6" + ] + }, + "id": "-Xbb0cuLzwgf", + "outputId": "3396a9bf-5d9b-45e7-a13e-ac57a78dd441" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n", + "🦥 Unsloth Zoo will now patch everything to make training faster!\n", + "==((====))== Unsloth 2025.3.14: Fast Gemma3 patching. Transformers: 4.50.0.dev0.\n", + " \\\\ /| Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.\n", + "O^O/ \\_/ \\ Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0\n", + "\\ / Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]\n", + " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", + "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n", + "Unsloth: Using float16 precision for gemma3 won't work! Using float32.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "33815b4c0485402e838127b32ad14a15", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "model.safetensors: 0%| | 0.00/4.44G [00:00\n", + "### Data Prep\n", + "We now use the `Gemma-3` format for conversation style finetunes. We use [Maxime Labonne's FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k) dataset in ShareGPT style. Gemma-3 renders multi turn conversations like below:\n", + "\n", + "```\n", + "user\n", + "Hello!\n", + "model\n", + "Hey there!\n", + "```\n", + "\n", + "We use our `get_chat_template` function to get the correct chat template. We support `zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, phi3, llama3, phi4, qwen2.5, gemma3` and more." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LjY75GoYUCB8" + }, + "outputs": [], + "source": [ + "from unsloth.chat_templates import get_chat_template\n", + "tokenizer = get_chat_template(\n", + " tokenizer,\n", + " chat_template = \"gemma-3\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 113, + "referenced_widgets": [ + "55164fb3b49848beabfe0da534ab4b97", + "79c551f4fd374acdb112c3d8d82c58a0", + "be024d1de576448fb472c3349b2d6d9e", + "173033e49f9f46d9930c6afddc983509", + "2b95af79a12749549da9bd21f0a944db", + "006e1f7432f84b298f61bb597b8aab10", + "f496e27e3d6446c7997c6fc549de69fa", + "1cb8f93e225f427f92c7bf7aa99aba8c", + "89aa2b1f06ab47d7ac0e4cdfd70d7bfe", + "441ff92725ef40ce807daa1ff721faef", + "21136eec72034fcbb47a19f92aa0664d", + "a74d75b5fbe34eb7b1033d21777ecddb", + "3477f57107754fb58bf5080869c81092", + "95371925c6db4abca3d7987775d1e7b0", + "a431edec70ad43569fe75342678ad9d7", + "b2bf96a0eae14ea78c0530c2574a94c7", + "9511c493311b4d4e94d5dc0aca4eeffb", + "f4265abad437426b8efa9091110c77c9", + "be166331dbc5446c8a4f45e861183b71", + "45616dc4ade54cb28ae49c9489d810ca", + "7ff7c3b516d84706a47e91c64b50d171", + "7850eeab711f494399e82bf96f942888", + "2ef58068693547b8b3111d944703d188", + "ec4009400bb34f05bd1554db5bbd9ea8", + "018b82e39081491fb1137b8fc2707f03", + "e0387cc0b0d7426d9e6a7a1cd42bb760", + "3526caa6bc6347aab7a1596a67b7b00e", + "28063728d31f45aea5beffd3f114eec7", + "d86e65b9ea5f47c58b3f2274e4989f93", + "1b703ddf8c7644f3969c999097242473", + "bcbb16a089ed4e78b2b2e0e6eacabc31", + "ce2ae7abaa4841478a732238269aa233", + "5fd55097dac149579c533b2798ebb442" + ] + }, + "id": "Mkq4RvEq7FQr", + "outputId": "3e99d12a-441e-4914-f48d-8184f8a5b9bc" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "55164fb3b49848beabfe0da534ab4b97", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "README.md: 0%| | 0.00/982 [00:00`!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 175 + }, + "id": "gGFzmplrEy9I", + "outputId": "7acc565b-0759-4438-cd0c-fa68d0570863" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'user\\nWhat is the modulus operator in programming and how can I use it to calculate the modulus of two given numbers?\\nmodel\\nIn programming, the modulus operator is represented by the \\'%\\' symbol. It calculates the remainder when one number is divided by another. To calculate the modulus of two given numbers, you can use the modulus operator in the following way:\\n\\n```python\\n# Calculate the modulus\\nModulus = a % b\\n\\nprint(\"Modulus of the given numbers is: \", Modulus)\\n```\\n\\nIn this code snippet, the variables \\'a\\' and \\'b\\' represent the two given numbers for which you want to calculate the modulus. By using the modulus operator \\'%\\', we calculate the remainder when \\'a\\' is divided by \\'b\\'. The result is then stored in the variable \\'Modulus\\'. Finally, the modulus value is printed using the \\'print\\' statement.\\n\\nFor example, if \\'a\\' is 10 and \\'b\\' is 4, the modulus calculation would be 10 % 4, which equals 2. 
Therefore, the output of the above code would be:\\n\\n```\\nModulus of the given numbers is: 2\\n```\\n\\nThis means that the modulus of 10 and 4 is 2.\\n'" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[100][\"text\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "idAEIeSQ3xdS" + }, + "source": [ + "\n", + "### Train the model\n", + "Now let's use Huggingface TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 112, + "referenced_widgets": [ + "730aa679b5ac483b929a3646bb5947fa", + "1f333859babd4c1abac69cacda6df864", + "e7f8d2c781a64e83988b0bdd090bdb97", + "3e3feb4fcca74c87abb608c0543236b1", + "2970cbc657d244bab22715bcb788be6a", + "c5b1d1476ddc45249e037df07a96ae37", + "19c27988e01d47e79319f89b5cfd73e2", + "5e48531593d741eaa3669f9118ba8afc", + "62838ffab83d486f86e86417caf0b498", + "7e5378838c114195ba3919fdd683fd7d", + "3350d22f463643ef9a726227f262ac49" + ] + }, + "id": "95_Nn-89DhsL", + "outputId": "2ffe3c70-8c46-41a1-ef51-07c220b5935d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Unsloth: Switching to float32 training since model cannot work with float16\n", + "Unsloth: We found double BOS tokens - we shall remove one automatically.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "730aa679b5ac483b929a3646bb5947fa", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Unsloth: Tokenizing [\"text\"] (num_proc=2): 0%| | 0/100000 [00:00user\\n\",\n", + " response_part = \"model\\n\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Dv1NBUozV78l" + }, + "source": [ + "Let's verify masking the instruction part is done! Let's print the 100th row again:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 175 + }, + "id": "LtsMVtlkUhja", + "outputId": "aebb55c2-3883-4494-e9f8-78d60b9b08e8" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'user\\nWhat is the modulus operator in programming and how can I use it to calculate the modulus of two given numbers?\\nmodel\\nIn programming, the modulus operator is represented by the \\'%\\' symbol. It calculates the remainder when one number is divided by another. To calculate the modulus of two given numbers, you can use the modulus operator in the following way:\\n\\n```python\\n# Calculate the modulus\\nModulus = a % b\\n\\nprint(\"Modulus of the given numbers is: \", Modulus)\\n```\\n\\nIn this code snippet, the variables \\'a\\' and \\'b\\' represent the two given numbers for which you want to calculate the modulus. By using the modulus operator \\'%\\', we calculate the remainder when \\'a\\' is divided by \\'b\\'. The result is then stored in the variable \\'Modulus\\'. Finally, the modulus value is printed using the \\'print\\' statement.\\n\\nFor example, if \\'a\\' is 10 and \\'b\\' is 4, the modulus calculation would be 10 % 4, which equals 2. 
Therefore, the output of the above code would be:\\n\\n```\\nModulus of the given numbers is: 2\\n```\\n\\nThis means that the modulus of 10 and 4 is 2.\\n'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode(trainer.train_dataset[100][\"input_ids\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4Kyjy__m9KY3" + }, + "source": [ + "Now let's print the masked out example - you should see only the answer is present:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 175 + }, + "id": "_rD6fl8EUxnG", + "outputId": "e9012c2a-60eb-437e-f145-3e11e9a0dd34" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "' In programming, the modulus operator is represented by the \\'%\\' symbol. It calculates the remainder when one number is divided by another. To calculate the modulus of two given numbers, you can use the modulus operator in the following way:\\n\\n```python\\n# Calculate the modulus\\nModulus = a % b\\n\\nprint(\"Modulus of the given numbers is: \", Modulus)\\n```\\n\\nIn this code snippet, the variables \\'a\\' and \\'b\\' represent the two given numbers for which you want to calculate the modulus. By using the modulus operator \\'%\\', we calculate the remainder when \\'a\\' is divided by \\'b\\'. The result is then stored in the variable \\'Modulus\\'. Finally, the modulus value is printed using the \\'print\\' statement.\\n\\nFor example, if \\'a\\' is 10 and \\'b\\' is 4, the modulus calculation would be 10 % 4, which equals 2. Therefore, the output of the above code would be:\\n\\n```\\nModulus of the given numbers is: 2\\n```\\n\\nThis means that the modulus of 10 and 4 is 2.\\n'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100][\"labels\"]]).replace(tokenizer.pad_token, \" \")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2ejIt2xSNKKp", + "outputId": "ba6de9bc-35f1-48ed-8552-5cf1943d0478" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GPU = Tesla T4. Max memory = 14.741 GB.\n", + "4.283 GB of memory reserved.\n" + ] + } + ], + "source": [ + "# @title Show current memory stats\n", + "gpu_stats = torch.cuda.get_device_properties(0)\n", + "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", + "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", + "print(f\"{start_gpu_memory} GB of memory reserved.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CNP1Uidk9mrz" + }, + "source": [ + "Let's train the model! 
To resume a training run, set `trainer.train(resume_from_checkpoint = True)`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "yqxqAZ7KJ4oL", + "outputId": "b44425bc-2ccf-4683-ce72-a837e5a07e9e" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1\n", + " \\\\ /| Num examples = 100,000 | Num Epochs = 1 | Total steps = 30\n", + "O^O/ \\_/ \\ Batch size per device = 2 | Gradient accumulation steps = 4\n", + "\\ / Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8\n", + " \"-____-\" Trainable parameters = 14,901,248/4,000,000,000 (0.37% trained)\n", + "It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Unsloth: Will smartly offload gradients to save VRAM!\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [30/30 16:18, Epoch 0/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
11.237700
21.636400
31.766300
41.420700
51.235700
61.806600
71.010100
81.896600
91.464700
101.309700
111.461600
121.867400
131.854700
141.394800
151.633000
161.238600
172.174100
181.488000
191.521400
201.816000
211.694700
221.614100
231.829700
241.616300
251.269200
261.291600
271.743000
281.525200
291.999000
301.858200

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer_stats = trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pCqnaKmlO1U9", + "outputId": "5d5d33ee-7a84-4418-b038-bd15fb4614e4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1068.4322 seconds used for training.\n", + "17.81 minutes used for training.\n", + "Peak reserved memory = 13.561 GB.\n", + "Peak reserved memory for training = 9.278 GB.\n", + "Peak reserved memory % of max memory = 91.995 %.\n", + "Peak reserved memory for training % of max memory = 62.94 %.\n" + ] + } + ], + "source": [ + "# @title Show final memory and time stats\n", + "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", + "used_percentage = round(used_memory / max_memory * 100, 3)\n", + "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", + "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", + "print(\n", + " f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n", + ")\n", + "print(f\"Peak reserved memory = {used_memory} GB.\")\n", + "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", + "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", + "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ekOmTR1hSNcr" + }, + "source": [ + "\n", + "### Inference\n", + "Let's run the model via Unsloth native inference! According to the `Gemma-3` team, the recommended settings for inference are `temperature = 1.0, top_p = 0.95, top_k = 64`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kR3gIAX-SM2q", + "outputId": "407daa07-ae31-4771-8c31-779665e53bd8" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['user\\nContinue the sequence: 1, 1, 2, 3, 5, 8,\\nmodel\\n13, 21, 34, 55, 89...\\n\\nThis is the Fibonacci sequence, where each number is the sum of the two preceding ones.\\n']" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from unsloth.chat_templates import get_chat_template\n", + "tokenizer = get_chat_template(\n", + " tokenizer,\n", + " chat_template = \"gemma-3\",\n", + ")\n", + "messages = [{\n", + " \"role\": \"user\",\n", + " \"content\": [{\n", + " \"type\" : \"text\",\n", + " \"text\" : \"Continue the sequence: 1, 1, 2, 3, 5, 8,\",\n", + " }]\n", + "}]\n", + "text = tokenizer.apply_chat_template(\n", + " messages,\n", + " add_generation_prompt = True, # Must add for generation\n", + ")\n", + "outputs = model.generate(\n", + " **tokenizer([text], return_tensors = \"pt\").to(\"cuda\"),\n", + " max_new_tokens = 64, # Increase for longer outputs!\n", + " # Recommended Gemma-3 settings!\n", + " temperature = 1.0, top_p = 0.95, top_k = 64,\n", + ")\n", + "tokenizer.batch_decode(outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CrSvZObor0lY" + }, + "source": [ + " You can also use a `TextStreamer` for continuous inference - so you can see the generation token by token, instead of waiting the whole time!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "e2pEuRb1r2Vg", + "outputId": "de757d2d-a66b-4be6-c9c9-78cf491dfeba" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Okay, let's break down why the sky is blue! It's a fascinating phenomenon that boils down to a combination of physics and light. Here's the explanation:\n", + "\n", + "**1. Sunlight and its Colors:**\n", + "\n", + "* Sunlight, which appears white to us, is actually made up of *all* the\n" + ] + } + ], + "source": [ + "messages = [{\n", + " \"role\": \"user\",\n", + " \"content\": [{\"type\" : \"text\", \"text\" : \"Why is the sky blue?\",}]\n", + "}]\n", + "text = tokenizer.apply_chat_template(\n", + " messages,\n", + " add_generation_prompt = True, # Must add for generation\n", + ")\n", + "\n", + "from transformers import TextStreamer\n", + "_ = model.generate(\n", + " **tokenizer([text], return_tensors = \"pt\").to(\"cuda\"),\n", + " max_new_tokens = 64, # Increase for longer outputs!\n", + " # Recommended Gemma-3 settings!\n", + " temperature = 1.0, top_p = 0.95, top_k = 64,\n", + " streamer = TextStreamer(tokenizer, skip_prompt = True),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uMuVrWbjAzhc" + }, + "source": [ + "\n", + "### Saving, loading finetuned models\n", + "To save the final model as LoRA adapters, either use Huggingface's `push_to_hub` for an online save or `save_pretrained` for a local save.\n", + "\n", + "**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "upcOlWe7A1vc", + "outputId": "a99a1086-5a2d-4828-d599-7e3634a069cd" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['gemma-3/processor_config.json']" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.save_pretrained(\"gemma-3\") # Local saving\n", + "tokenizer.save_pretrained(\"gemma-3\")\n", + "# model.push_to_hub(\"HF_ACCOUNT/gemma-3\", token = \"...\") # Online saving\n", + "# tokenizer.push_to_hub(\"HF_ACCOUNT/gemma-3\", token = \"...\") # Online saving" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AEEcJ4qfC7Lp" + }, + "source": [ + "Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "MKX_XKs_BNZR", + "outputId": "d016d936-4bd5-40f8-dffa-bcfad987f489" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Okay, let's break down what Gemma-3 is. It's a fascinating development in the world of AI, and here's a comprehensive overview:\n", + "\n", + "**1. 
What it is:**\n", + "\n", + "* **A Family of Open-Weight Language Models:** Gemma-3 isn't just *one* model\n" + ] + } + ], + "source": [ + "if False:\n", + " from unsloth import FastModel\n", + " model, tokenizer = FastModel.from_pretrained(\n", + " model_name = \"lora_model\", # YOUR MODEL YOU USED FOR TRAINING\n", + " max_seq_length = 2048,\n", + " load_in_4bit = True,\n", + " )\n", + "\n", + "messages = [{\n", + " \"role\": \"user\",\n", + " \"content\": [{\"type\" : \"text\", \"text\" : \"What is Gemma-3?\",}]\n", + "}]\n", + "text = tokenizer.apply_chat_template(\n", + " messages,\n", + " add_generation_prompt = True, # Must add for generation\n", + ")\n", + "\n", + "from transformers import TextStreamer\n", + "_ = model.generate(\n", + " **tokenizer([text], return_tensors = \"pt\").to(\"cuda\"),\n", + " max_new_tokens = 64, # Increase for longer outputs!\n", + " # Recommended Gemma-3 settings!\n", + " temperature = 1.0, top_p = 0.95, top_k = 64,\n", + " streamer = TextStreamer(tokenizer, skip_prompt = True),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f422JgM9sdVT" + }, + "source": [ + "### Saving to float16 for VLLM\n", + "\n", + "We also support saving to `float16` directly for deployment! We save it in the folder `gemma-3-finetune`. Set `if False` to `if True` to let it run!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iHjt_SMYsd3P" + }, + "outputs": [], + "source": [ + "if False: # Change to True to save finetune!\n", + " model.save_pretrained_merged(\"gemma-3-finetune\", tokenizer)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z6O48DbNIAr0" + }, + "source": [ + "If you want to upload / push to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZV-CiKPrIFG0" + }, + "outputs": [], + "source": [ + "if False: # Change to True to upload finetune\n", + " model.push_to_hub_merged(\n", + " \"HF_ACCOUNT/gemma-3-finetune\", tokenizer,\n", + " token = \"hf_...\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TCv4vXHd61i7" + }, + "source": [ + "### GGUF / llama.cpp Conversion\n", + "To save to `GGUF` / `llama.cpp`, we support it natively now for all models! For now, you can convert easily to `Q8_0, F16 or BF16` precision. `Q4_K_M` for 4bit will come later!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FqfebeAdT073" + }, + "outputs": [], + "source": [ + "if False: # Change to True to save to GGUF\n", + " model.save_pretrained_gguf(\n", + " \"gemma-3-finetune\",\n", + " quantization_type = \"Q8_0\", # For now only Q8_0, BF16, F16 supported\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Q974YEVPI7JS" + }, + "source": [ + "Likewise, if you want to instead push to GGUF to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZgcJIhJ0I_es" + }, + "outputs": [], + "source": [ + "if False: # Change to True to upload GGUF\n", + " model.push_to_hub_gguf(\n", + " \"gemma-3-finetune\",\n", + " quantization_type = \"Q8_0\", # Only Q8_0, BF16, F16 supported\n", + " repo_id = \"HF_ACCOUNT/gemma-finetune-gguf\",\n", + " token = \"hf_...\",\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CWSZ7ZSvjeWk" + }, + "source": [ + "Now, use the `gemma-3-finetune.gguf` file or `gemma-3-finetune-Q4_K_M.gguf` file in llama.cpp or a UI-based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui).\n", + "\n", + "And we're done! If you have any questions about Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs, want to keep up with the latest LLM news, need help, or want to join projects, feel free to join our Discord!\n", + "\n", + "Some other links:\n", + "1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)\n", + "2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)\n", + "3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)\n", + "4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!\n", + "\n", + "

\n", + " \n", + " \n", + " \n", + "\n", + " Join Discord if you need help + ⭐️ Star us on Github ⭐️\n", + "
\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "006e1f7432f84b298f61bb597b8aab10": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "018b82e39081491fb1137b8fc2707f03": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1b703ddf8c7644f3969c999097242473", + "max": 100000, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_bcbb16a089ed4e78b2b2e0e6eacabc31", + "value": 100000 + } + }, + "039f461e15214bd697501219bd9cbbd9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + 
"063db74d47814f95b560bd3bab11b55f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3530e2b431c041c6aeeaca4808ba0424", + "placeholder": "​", + "style": "IPY_MODEL_cc14e51320f34274ae12aa28b04183b7", + "value": " 35.0/35.0 [00:00<00:00, 2.05kB/s]" + } + }, + "083eaaf245954cc998e8f20e8120eed1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0881c055108340f7ab4b840ac1545cbb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "090b146b8c6f4fa1ab0261d2a61be9de": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + 
"_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_708d05af00a64d19bb11dc839a5e68db", + "placeholder": "​", + "style": "IPY_MODEL_3d20e28d74e549c0a43686b214eebd87", + "value": "preprocessor_config.json: 100%" + } + }, + "0a67ecf349484aa18e1151aa2bc1a9ae": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4b1e445cfec04a0f8566be97d8388310", + "placeholder": "​", + "style": "IPY_MODEL_68f1f454dee4462ebf03789f2df9e0f9", + "value": "Unsloth: Standardizing formats (num_proc=2): 100%" + } + }, + "0abe1a2b12e54bb5a07db1a8f3a77738": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0b1768a5b5be4a4d9e14ed9e168044b4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0c3e3fbf02d84114906e939fcca108b5": { + "model_module": 
"@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c4a69698321d435c94211a9dee913c45", + "placeholder": "​", + "style": "IPY_MODEL_36c8135711194884be0e03cb5d3ae7e5", + "value": " 1.61k/1.61k [00:00<00:00, 181kB/s]" + } + }, + "0cba80b626574c11a44c6ce09b5d6e80": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_12d3049cca4a46c08cf5cdfcd5225248", + "placeholder": "​", + "style": "IPY_MODEL_6a9baf0a739c4790baf99b0ccadd6873", + "value": " 33.4M/33.4M [00:00<00:00, 70.2MB/s]" + } + }, + "0db66084c58047aa84edf3b4fbede0c0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_0a67ecf349484aa18e1151aa2bc1a9ae", + "IPY_MODEL_d2ecda1542b24e4fb859cedc23a8e5ef", + "IPY_MODEL_ea31c3d67a98483ca537bc5917b1aee1" + ], + "layout": "IPY_MODEL_083eaaf245954cc998e8f20e8120eed1" + } + }, + "12d3049cca4a46c08cf5cdfcd5225248": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "148125f955954041a8f5631f9338c43e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + 
"_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_54d2fb7e7c2b4107ba758a7c7ef4f382", + "placeholder": "​", + "style": "IPY_MODEL_34da5010faa749e0940c2821f2f46e59", + "value": " 100000/100000 [00:43<00:00, 2627.58 examples/s]" + } + }, + "173033e49f9f46d9930c6afddc983509": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_441ff92725ef40ce807daa1ff721faef", + "placeholder": "​", + "style": "IPY_MODEL_21136eec72034fcbb47a19f92aa0664d", + "value": " 982/982 [00:00<00:00, 78.7kB/s]" + } + }, + "17eaf723882e4efea38119978166fc75": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_cebd4fbf1fcf4ab2b65ecf539eda5a1e", + "placeholder": "​", + "style": "IPY_MODEL_f14ea72beac74152af5f970e634769ca", + "value": "generation_config.json: 100%" + } + }, + "18fce8679d7d4961b88bb2162e7aa9eb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f9439c3c9b3b4c4a84ed67aa0601a530", + "placeholder": "​", + "style": "IPY_MODEL_a89575b4c58348ac98566c22ab7e4118", + "value": "chat_template.json: 100%" + } + }, + "19c27988e01d47e79319f89b5cfd73e2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "19c3ef35452d406cb18b72e38b631ee7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_77e843913b6e439ead9cf42725eedf3d", + "max": 35, + "min": 0, + "orientation": "horizontal", + "style": 
"IPY_MODEL_4f7e8b6b71484cce80ae9cdf0c481825", + "value": 35 + } + }, + "1b703ddf8c7644f3969c999097242473": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1be74564b60c48d6b21615adf29009fb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ded71beadafd438ebff07bb0594771e4", + "max": 100000, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_677e4d7a08ab408e9430d67a2870f707", + "value": 100000 + } + }, + "1cb8f93e225f427f92c7bf7aa99aba8c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1ecb60e1d5934ea19a0d2c29aa01158d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": 
"LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1f333859babd4c1abac69cacda6df864": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c5b1d1476ddc45249e037df07a96ae37", + "placeholder": "​", + "style": "IPY_MODEL_19c27988e01d47e79319f89b5cfd73e2", + "value": "Unsloth: Tokenizing ["text"] (num_proc=2): 100%" + } + }, + "21136eec72034fcbb47a19f92aa0664d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "26b61507603d453c8c24af24d301bdb9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9e20d704e64a4aaabe5e495c468d9670", + "placeholder": "​", + "style": "IPY_MODEL_ae4e4537ed1d4df4b1677700d190a2d2", + "value": " 192/192 [00:00<00:00, 20.1kB/s]" + } + }, + "28063728d31f45aea5beffd3f114eec7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + 
"grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "28de62814a6847e0a0b41ec6bf8fdc66": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2970cbc657d244bab22715bcb788be6a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2b95af79a12749549da9bd21f0a944db": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + 
"flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2ef58068693547b8b3111d944703d188": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_ec4009400bb34f05bd1554db5bbd9ea8", + "IPY_MODEL_018b82e39081491fb1137b8fc2707f03", + "IPY_MODEL_e0387cc0b0d7426d9e6a7a1cd42bb760" + ], + "layout": "IPY_MODEL_3526caa6bc6347aab7a1596a67b7b00e" + } + }, + "2fa2784330b2439ab884ce0966037961": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "30918e50f2174d1c8e7af3eef332b6ee": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_f4b34bc9a62f405383c3b81dc87792b9", + "IPY_MODEL_1be74564b60c48d6b21615adf29009fb", + "IPY_MODEL_148125f955954041a8f5631f9338c43e" + ], + "layout": "IPY_MODEL_91693f16da7b421885fe8474cf533327" + } + }, + "3350d22f463643ef9a726227f262ac49": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": 
"1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "33815b4c0485402e838127b32ad14a15": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_81c3ee8d97d543ae83db3d2748ecd7cd", + "IPY_MODEL_70f9bccb201a4ea79c1827057cd746f6", + "IPY_MODEL_b32538030b8649e7851bfd58bef2786c" + ], + "layout": "IPY_MODEL_0abe1a2b12e54bb5a07db1a8f3a77738" + } + }, + "3477f57107754fb58bf5080869c81092": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9511c493311b4d4e94d5dc0aca4eeffb", + "placeholder": "​", + "style": "IPY_MODEL_f4265abad437426b8efa9091110c77c9", + "value": "train-00000-of-00001.parquet: 100%" + } + }, + "34da5010faa749e0940c2821f2f46e59": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3526caa6bc6347aab7a1596a67b7b00e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3530e2b431c041c6aeeaca4808ba0424": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": 
"LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "353fa47fb98c4070a150edec64503eaa": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "360142dac8c54a5eb902078ec42abb65": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_70fdd31291b04dd68d66cf31c03d23ff", + "placeholder": "​", + "style": "IPY_MODEL_6feaf338d39440e78221649ac84af4a6", + "value": "special_tokens_map.json: 100%" + } + }, + "36c8135711194884be0e03cb5d3ae7e5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3a24c19f9a6a4e4dba162062a0562db2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + 
"_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3bbc3e42eb0b4a70bb5bf643c4eb0d40": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c0e4c836aa704fd98349a577c2e6ca15", + "max": 100000, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_59930ab85ef24940963b2e2c7e578adc", + "value": 100000 + } + }, + "3bfc7d7dd81f49a59fc8ad0d6fff858b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_84c68f2059f247628b672b1079130e9b", + "max": 70, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_a88fccae11664baa82418c88639f3521", + "value": 70 + } + }, + "3d20e28d74e549c0a43686b214eebd87": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3e3feb4fcca74c87abb608c0543236b1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7e5378838c114195ba3919fdd683fd7d", + "placeholder": "​", + "style": "IPY_MODEL_3350d22f463643ef9a726227f262ac49", + "value": " 100000/100000 [03:03<00:00, 568.03 examples/s]" + } + }, + "409488926d2242c5a8e7b3d5b79c59db": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_039f461e15214bd697501219bd9cbbd9", + "max": 192, + "min": 0, + "orientation": 
"horizontal", + "style": "IPY_MODEL_9f9d2f43fb5e47df883feb3126fe52e9", + "value": 192 + } + }, + "40bf6f1ffbc5479082fdd7ab153ea974": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "441ff92725ef40ce807daa1ff721faef": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4466f20e614a4cdabe1704859c2f1034": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_be2b1fb954444be089045a378673a958", + "max": 4689074, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_9520485cc7c24c8584c3838717655012", + "value": 4689074 + } + }, + "45616dc4ade54cb28ae49c9489d810ca": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "457c60d6a15d4314ba25d370be956a60": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, 
+ "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "490baae8c2294c2194f32dae183ce833": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4b1e445cfec04a0f8566be97d8388310": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4c5e29de7224428bb87de46854ea915a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5e9ba3247edc4fafa7687338424ddccb", + "placeholder": "​", + "style": "IPY_MODEL_97d7e3420e24436cb351b1e9679ff8b6", + "value": " 670/670 [00:00<00:00, 42.7kB/s]" + } + }, + "4e7cbb77617f43e6ae81e42eb7086fb7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_f98456d6dbe143448c31378bf4f101d0", + "IPY_MODEL_3bbc3e42eb0b4a70bb5bf643c4eb0d40", + "IPY_MODEL_923d9b71e3f742da94b11731381eb411" + ], + "layout": "IPY_MODEL_58fa3ae2cc264f7cb988a2258bf8b4e3" + } + }, + "4f7e8b6b71484cce80ae9cdf0c481825": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "50881898da2f4b35a288ac9befe5024e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_599375adc5f841d1864166e4d1bd617a", + "IPY_MODEL_3bfc7d7dd81f49a59fc8ad0d6fff858b", + "IPY_MODEL_a0c0025c82394e7fbc6d4cc9a9e9f72f" + ], + "layout": "IPY_MODEL_5f0a37b9edc74cbd822e2e71c6c8a956" + } + }, + "51649c42cf2145f1b1e90c3436805350": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "529996f48ccf404584725386d33401f9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "52f091f99f0442f8bd14a50b7c870c1e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + 
"align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "54d2fb7e7c2b4107ba758a7c7ef4f382": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "54d963afc22a46ae93a6ca4bfaa18cf8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_52f091f99f0442f8bd14a50b7c870c1e", + "placeholder": "​", + "style": "IPY_MODEL_c859133fca324effb73ebc3520e746b6", + "value": " 570/570 [00:00<00:00, 61.1kB/s]" + } + }, + "55164fb3b49848beabfe0da534ab4b97": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_79c551f4fd374acdb112c3d8d82c58a0", + "IPY_MODEL_be024d1de576448fb472c3349b2d6d9e", + "IPY_MODEL_173033e49f9f46d9930c6afddc983509" + ], + "layout": "IPY_MODEL_2b95af79a12749549da9bd21f0a944db" + } + }, + "57ed5097e05f4d92a5c492826f989123": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": 
"FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b5a06422fcac41eb97d2de95f14b1806", + "max": 1157008, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_efa41d07d0fa4adda8025fe9490ed850", + "value": 1157008 + } + }, + "58fa3ae2cc264f7cb988a2258bf8b4e3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5975cf24b18e4082bd80e3f177e0ec15": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b5b8482ef7c44e12a83795e7337521c2", + "IPY_MODEL_ea9045a5c4504a5e96e6a7b13767fe4e", + "IPY_MODEL_0cba80b626574c11a44c6ce09b5d6e80" + ], + "layout": "IPY_MODEL_f344c1ab154b4abdb84b2c221b6162a1" + } + }, + "59930ab85ef24940963b2e2c7e578adc": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "599375adc5f841d1864166e4d1bd617a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a15e4524521b42108f49dda23ed56023", + 
"placeholder": "​", + "style": "IPY_MODEL_a4eaae1b30d442208257c1870d549738", + "value": "processor_config.json: 100%" + } + }, + "5c01ab4767104c0c96c42858317f8877": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5e48531593d741eaa3669f9118ba8afc": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5e9ba3247edc4fafa7687338424ddccb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + 
"max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5f0a37b9edc74cbd822e2e71c6c8a956": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5fd55097dac149579c533b2798ebb442": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "62838ffab83d486f86e86417caf0b498": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "677e4d7a08ab408e9430d67a2870f707": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "68f1f454dee4462ebf03789f2df9e0f9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "697a187a07204fdfb8556cf5c5028c6b": { + "model_module": "@jupyter-widgets/base", + 
"model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "69a9d565de6a4d54b8f3989fa8b11941": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_090b146b8c6f4fa1ab0261d2a61be9de", + "IPY_MODEL_d2df1617a4094dd295c49a0a72b269c3", + "IPY_MODEL_54d963afc22a46ae93a6ca4bfaa18cf8" + ], + "layout": "IPY_MODEL_0b1768a5b5be4a4d9e14ed9e168044b4" + } + }, + "6a9baf0a739c4790baf99b0ccadd6873": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6feaf338d39440e78221649ac84af4a6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "708d05af00a64d19bb11dc839a5e68db": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + 
"grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "70f9bccb201a4ea79c1827057cd746f6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "danger", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9be36a19ba86494389904b9fffdc2e48", + "max": 4437712931, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_899bf9a4dc594e178d0b95e3cbe08018", + "value": 4437712508 + } + }, + "70fdd31291b04dd68d66cf31c03d23ff": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "730aa679b5ac483b929a3646bb5947fa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_1f333859babd4c1abac69cacda6df864", + "IPY_MODEL_e7f8d2c781a64e83988b0bdd090bdb97", + "IPY_MODEL_3e3feb4fcca74c87abb608c0543236b1" + ], + "layout": "IPY_MODEL_2970cbc657d244bab22715bcb788be6a" + } + }, + "73a283e64f324c27b38216e683050b92": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + 
"align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "740d351b7de241a6acabf6c2853585b6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "77907c3444174858bbdc548dee8d0d37": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "77c5f8b431ba4c08b8f4d9d8f9fafc16": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_73a283e64f324c27b38216e683050b92", + "placeholder": "​", + "style": "IPY_MODEL_81bddfaa180d4875b5cdb5cc4ae45dab", + "value": "added_tokens.json: 100%" + } + }, + "77e843913b6e439ead9cf42725eedf3d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + 
"overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7850eeab711f494399e82bf96f942888": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "79c551f4fd374acdb112c3d8d82c58a0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_006e1f7432f84b298f61bb597b8aab10", + "placeholder": "​", + "style": "IPY_MODEL_f496e27e3d6446c7997c6fc549de69fa", + "value": "README.md: 100%" + } + }, + "7e5378838c114195ba3919fdd683fd7d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7ff7c3b516d84706a47e91c64b50d171": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + 
"object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "81bddfaa180d4875b5cdb5cc4ae45dab": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "81c3ee8d97d543ae83db3d2748ecd7cd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e12bca7d49b149b5b1124c28a669db98", + "placeholder": "​", + "style": "IPY_MODEL_fab7a02d350946ae9563a05cfd04e22b", + "value": "model.safetensors: 100%" + } + }, + "84c68f2059f247628b672b1079130e9b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "86c6d49a55b3477bbccc275dcb55fb52": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_77c5f8b431ba4c08b8f4d9d8f9fafc16", + "IPY_MODEL_19c3ef35452d406cb18b72e38b631ee7", + "IPY_MODEL_063db74d47814f95b560bd3bab11b55f" + ], + "layout": "IPY_MODEL_5c01ab4767104c0c96c42858317f8877" + } + }, + "879c6a0498e54e5c87145c7f7d32de7e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + 
"_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8864799f440c440c8ff8c0696e64215d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8949c35d68a043c5b1384774bb07b2ea": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_94d9900bc8934de688e95e15c9d0c9bb", + "placeholder": "​", + "style": "IPY_MODEL_b469cdf580404f498feb062e4dbad10b", + "value": " 4.69M/4.69M [00:00<00:00, 22.1MB/s]" + } + }, + "899bf9a4dc594e178d0b95e3cbe08018": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "89aa2b1f06ab47d7ac0e4cdfd70d7bfe": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + 
"_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "91693f16da7b421885fe8474cf533327": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "91f1a912832140089e20d35a01de6c74": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "923d9b71e3f742da94b11731381eb411": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ebecfe57cabc46028198957fb45114cc", + "placeholder": "​", + "style": "IPY_MODEL_ba9001d2d8db4f75abb557fc2306bf5b", + "value": " 100000/100000 [00:17<00:00, 6582.77 examples/s]" + } + }, + "94d9900bc8934de688e95e15c9d0c9bb": { + "model_module": 
"@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9511c493311b4d4e94d5dc0aca4eeffb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9520485cc7c24c8584c3838717655012": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "95371925c6db4abca3d7987775d1e7b0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "danger", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_be166331dbc5446c8a4f45e861183b71", + "max": 116531415, + "min": 0, + "orientation": 
"horizontal", + "style": "IPY_MODEL_45616dc4ade54cb28ae49c9489d810ca", + "value": 116531404 + } + }, + "97d7e3420e24436cb351b1e9679ff8b6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "99685b8ae00c40d5b950cbf44b15efa1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "9be36a19ba86494389904b9fffdc2e48": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9e20d704e64a4aaabe5e495c468d9670": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + 
"9f9d2f43fb5e47df883feb3126fe52e9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a0c0025c82394e7fbc6d4cc9a9e9f72f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1ecb60e1d5934ea19a0d2c29aa01158d", + "placeholder": "​", + "style": "IPY_MODEL_f920f6cb263c45d59b851b7c6b631cb5", + "value": " 70.0/70.0 [00:00<00:00, 8.21kB/s]" + } + }, + "a0cf0f8dd74a4d5181d465fa7ff52915": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a15e4524521b42108f49dda23ed56023": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + 
"padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a37b4e454c6743a895f159f963366bc8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_18fce8679d7d4961b88bb2162e7aa9eb", + "IPY_MODEL_fd60bb21ba2e474cb4a6020bd302835e", + "IPY_MODEL_0c3e3fbf02d84114906e939fcca108b5" + ], + "layout": "IPY_MODEL_ab71391446d8482c834429a937f7bb96" + } + }, + "a3ce4f38be9a456c81146fba440c8e3f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a431edec70ad43569fe75342678ad9d7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7ff7c3b516d84706a47e91c64b50d171", + "placeholder": "​", + "style": "IPY_MODEL_7850eeab711f494399e82bf96f942888", + "value": " 117M/117M [00:00<00:00, 213MB/s]" + } + }, + "a4a64136f6fc48c799abf1701725534e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c9fbf40fa3dd4ba1b363802dc88764da", + "max": 670, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_ae7f1fd06ddc4881934690675891855c", + "value": 670 + } + }, + "a4eaae1b30d442208257c1870d549738": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + 
"model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a74d75b5fbe34eb7b1033d21777ecddb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_3477f57107754fb58bf5080869c81092", + "IPY_MODEL_95371925c6db4abca3d7987775d1e7b0", + "IPY_MODEL_a431edec70ad43569fe75342678ad9d7" + ], + "layout": "IPY_MODEL_b2bf96a0eae14ea78c0530c2574a94c7" + } + }, + "a88fccae11664baa82418c88639f3521": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a89575b4c58348ac98566c22ab7e4118": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a92be3fb752148d887e20afe300f9371": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "aa36bff36ae5448b892afea071fc1f1d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "ab71391446d8482c834429a937f7bb96": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ae4e4537ed1d4df4b1677700d190a2d2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ae7f1fd06ddc4881934690675891855c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "afdb7dfdc17548b39daea1f39d54b45c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b2bf96a0eae14ea78c0530c2574a94c7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + 
"grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b32538030b8649e7851bfd58bef2786c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2fa2784330b2439ab884ce0966037961", + "placeholder": "​", + "style": "IPY_MODEL_51649c42cf2145f1b1e90c3436805350", + "value": " 4.44G/4.44G [00:31<00:00, 142MB/s]" + } + }, + "b469cdf580404f498feb062e4dbad10b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b5a06422fcac41eb97d2de95f14b1806": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b5b8482ef7c44e12a83795e7337521c2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0881c055108340f7ab4b840ac1545cbb", + "placeholder": "​", + "style": "IPY_MODEL_77907c3444174858bbdc548dee8d0d37", + "value": "tokenizer.json: 
100%" + } + }, + "ba9001d2d8db4f75abb557fc2306bf5b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "bc6f29c9a1e14ce8be374867b8be86ac": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "bcbb16a089ed4e78b2b2e0e6eacabc31": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "be024d1de576448fb472c3349b2d6d9e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1cb8f93e225f427f92c7bf7aa99aba8c", + "max": 982, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_89aa2b1f06ab47d7ac0e4cdfd70d7bfe", + "value": 982 + } + }, + "be166331dbc5446c8a4f45e861183b71": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be2b1fb954444be089045a378673a958": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + 
"state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c0e4c836aa704fd98349a577c2e6ca15": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c307400fb17e49ea9d835822e6e22633": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_353fa47fb98c4070a150edec64503eaa", + "placeholder": "​", + "style": "IPY_MODEL_40bf6f1ffbc5479082fdd7ab153ea974", + "value": "tokenizer_config.json: 100%" + } + }, + "c43cef665c9542f982986a74dc50ca98": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c4a69698321d435c94211a9dee913c45": { + "model_module": 
"@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c5b1d1476ddc45249e037df07a96ae37": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c6c51a350a0b420aab957e3570098e18": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, 
+ "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c859133fca324effb73ebc3520e746b6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c9fbf40fa3dd4ba1b363802dc88764da": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cc14e51320f34274ae12aa28b04183b7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ce2ae7abaa4841478a732238269aa233": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + 
"cebd4fbf1fcf4ab2b65ecf539eda5a1e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d2df1617a4094dd295c49a0a72b269c3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_dfac71c6372c46bfad46308bdb04480b", + "max": 570, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_df50d8daa49a4955914704edb89baf61", + "value": 570 + } + }, + "d2ecda1542b24e4fb859cedc23a8e5ef": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_91f1a912832140089e20d35a01de6c74", + "max": 100000, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_99685b8ae00c40d5b950cbf44b15efa1", + "value": 100000 + } + }, + "d86e65b9ea5f47c58b3f2274e4989f93": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "db7e622bbd0f4357b5687a7c09c1f6fd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + 
"_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_457c60d6a15d4314ba25d370be956a60", + "placeholder": "​", + "style": "IPY_MODEL_bc6f29c9a1e14ce8be374867b8be86ac", + "value": " 1.16M/1.16M [00:00<00:00, 10.7MB/s]" + } + }, + "dd6e4f9b4c6d4260a62b920a6812fd07": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_879c6a0498e54e5c87145c7f7d32de7e", + "placeholder": "​", + "style": "IPY_MODEL_740d351b7de241a6acabf6c2853585b6", + "value": "tokenizer.model: 100%" + } + }, + "ded71beadafd438ebff07bb0594771e4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "df50d8daa49a4955914704edb89baf61": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "dfac71c6372c46bfad46308bdb04480b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + 
"justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e003ee5cf1804ce3928b544b3fa7ba77": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_dd6e4f9b4c6d4260a62b920a6812fd07", + "IPY_MODEL_4466f20e614a4cdabe1704859c2f1034", + "IPY_MODEL_8949c35d68a043c5b1384774bb07b2ea" + ], + "layout": "IPY_MODEL_fca09f95775047efa9d481173f1ba261" + } + }, + "e0387cc0b0d7426d9e6a7a1cd42bb760": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ce2ae7abaa4841478a732238269aa233", + "placeholder": "​", + "style": "IPY_MODEL_5fd55097dac149579c533b2798ebb442", + "value": " 100000/100000 [00:02<00:00, 60707.08 examples/s]" + } + }, + "e12bca7d49b149b5b1124c28a669db98": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e7f8d2c781a64e83988b0bdd090bdb97": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + 
"layout": "IPY_MODEL_5e48531593d741eaa3669f9118ba8afc", + "max": 100000, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_62838ffab83d486f86e86417caf0b498", + "value": 100000 + } + }, + "ea31c3d67a98483ca537bc5917b1aee1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_490baae8c2294c2194f32dae183ce833", + "placeholder": "​", + "style": "IPY_MODEL_3a24c19f9a6a4e4dba162062a0562db2", + "value": " 100000/100000 [00:07<00:00, 13546.66 examples/s]" + } + }, + "ea9045a5c4504a5e96e6a7b13767fe4e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a92be3fb752148d887e20afe300f9371", + "max": 33384568, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_aa36bff36ae5448b892afea071fc1f1d", + "value": 33384568 + } + }, + "eae020595f574192ad9d132853bbf6ec": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_17eaf723882e4efea38119978166fc75", + "IPY_MODEL_409488926d2242c5a8e7b3d5b79c59db", + "IPY_MODEL_26b61507603d453c8c24af24d301bdb9" + ], + "layout": "IPY_MODEL_c6c51a350a0b420aab957e3570098e18" + } + }, + "ebecfe57cabc46028198957fb45114cc": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } 
+ }, + "ec4009400bb34f05bd1554db5bbd9ea8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_28063728d31f45aea5beffd3f114eec7", + "placeholder": "​", + "style": "IPY_MODEL_d86e65b9ea5f47c58b3f2274e4989f93", + "value": "Generating train split: 100%" + } + }, + "efa41d07d0fa4adda8025fe9490ed850": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "f14ea72beac74152af5f970e634769ca": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "f344c1ab154b4abdb84b2c221b6162a1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f373f2c24f3b413aaa9fe1ccfb9c1eab": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_360142dac8c54a5eb902078ec42abb65", + "IPY_MODEL_a4a64136f6fc48c799abf1701725534e", + "IPY_MODEL_4c5e29de7224428bb87de46854ea915a" + ], + "layout": "IPY_MODEL_a3ce4f38be9a456c81146fba440c8e3f" + } + }, + 
"f4265abad437426b8efa9091110c77c9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "f496e27e3d6446c7997c6fc549de69fa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "f4b34bc9a62f405383c3b81dc87792b9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_28de62814a6847e0a0b41ec6bf8fdc66", + "placeholder": "​", + "style": "IPY_MODEL_c43cef665c9542f982986a74dc50ca98", + "value": "Map (num_proc=2): 100%" + } + }, + "f89c08592a25432497bb312f58a13c5c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_c307400fb17e49ea9d835822e6e22633", + "IPY_MODEL_57ed5097e05f4d92a5c492826f989123", + "IPY_MODEL_db7e622bbd0f4357b5687a7c09c1f6fd" + ], + "layout": "IPY_MODEL_8864799f440c440c8ff8c0696e64215d" + } + }, + "f920f6cb263c45d59b851b7c6b631cb5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "f9439c3c9b3b4c4a84ed67aa0601a530": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + 
"grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f98456d6dbe143448c31378bf4f101d0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a0cf0f8dd74a4d5181d465fa7ff52915", + "placeholder": "​", + "style": "IPY_MODEL_529996f48ccf404584725386d33401f9", + "value": "Map: 100%" + } + }, + "fab7a02d350946ae9563a05cfd04e22b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "fca09f95775047efa9d481173f1ba261": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fd60bb21ba2e474cb4a6020bd302835e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_697a187a07204fdfb8556cf5c5028c6b", + "max": 1615, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_afdb7dfdc17548b39daea1f39d54b45c", + "value": 1615 + } + } + } + } + }, 
+ "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/end2end/unsloth/llama3_1_4b/llama3_1.py b/end2end/unsloth/llama3_1_4b/llama3_1.py new file mode 100644 index 0000000..c714956 --- /dev/null +++ b/end2end/unsloth/llama3_1_4b/llama3_1.py @@ -0,0 +1,197 @@ +# Original code from https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb +import argparse +import torch + + +from datasets import load_dataset +import triton.profiler as proton + +SUPPORTS_BFLOAT16 = torch.cuda.get_device_capability()[0] >= 8 + + +def format_dataset(dataset, tokenizer): + + alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + + ### Instruction: + {} + + ### Input: + {} + + ### Response: + {}""" + + EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN + def formatting_prompts_func(examples): + instructions = examples["instruction"] + inputs = examples["input"] + outputs = examples["output"] + texts = [] + for instruction, input, output in zip(instructions, inputs, outputs): + # Must add EOS_TOKEN, otherwise your generation will go on forever! + text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN + texts.append(text) + return { "text" : texts, } + pass + + return dataset.map(formatting_prompts_func, batched = True,) + + + +def train_unsloth(args): + from unsloth import FastLanguageModel + from trl import SFTTrainer + from transformers import TrainingArguments + + max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally! + dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ + load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False. + + model, tokenizer = FastLanguageModel.from_pretrained( + model_name = "unsloth/Meta-Llama-3.1-8B", + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, + cache_dir = "/scratch/jlee436/unsloth/model", + # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf + ) + + dataset = load_dataset("yahma/alpaca-cleaned", split = "train", cache_dir="/scratch/jlee436/unsloth/data") + dataset = format_dataset(dataset, tokenizer) + + + model = FastLanguageModel.get_peft_model( + model, + r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 + target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj",], + lora_alpha = 16, + lora_dropout = 0, # Supports any, but = 0 is optimized + bias = "none", # Supports any, but = "none" is optimized + # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes! + use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context + random_state = 3407, + use_rslora = False, # We support rank stabilized LoRA + loftq_config = None, # And LoftQ + ) + + + trainer = SFTTrainer( + model = model, + tokenizer = tokenizer, + train_dataset = dataset, + dataset_text_field = "text", + max_seq_length = max_seq_length, + dataset_num_proc = 2, + packing = False, # Can make training 5x faster for short sequences. + args = TrainingArguments( + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + warmup_steps = 5, + # num_train_epochs = 1, # Set this for 1 full training run. 
+ max_steps = 60, + learning_rate = 2e-4, + fp16 = not SUPPORTS_BFLOAT16, + bf16 = SUPPORTS_BFLOAT16, + optim = "adamw_8bit", + weight_decay = 0.01, + lr_scheduler_type = "linear", + seed = 3407, + output_dir = "outputs", + report_to = "none", # Use this for WandB etc + ), + ) + + if args.profiling == "proton": + session_id = proton.start(name="profile_name", context="shadow") + trainer_stats = trainer.train() + proton.finalize(session_id) + elif args.profiling == "torch": + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + trainer_stats = trainer.train() + else: + trainer_stats = trainer.train() + + +def train_native(): + from trl import SFTTrainer + from transformers import TrainingArguments + from transformers import AutoModelForCausalLM, AutoTokenizer + from peft import get_peft_model + from peft import LoraConfig, TaskType + + + + + max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally! + dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ + load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False. + + model, tokenizer = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path = "unsloth/Meta-Llama-3.1-8B", + # max_seq_length = max_seq_length, + # dtype = dtype, + # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf + ), AutoTokenizer.from_pretrained("unsloth/Meta-Llama-3.1-8B") + + lora_config = LoraConfig( + r=16, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj",], + task_type=TaskType.CAUSAL_LM, + lora_alpha=16, + lora_dropout=0, + use_rslora=False, + # loftq_config=None, + bias="none", + ) + + # lora_model = get_peft_model(model, lora_config) + + dataset = load_dataset("yahma/alpaca-cleaned", split = "train") + dataset = format_dataset(dataset, tokenizer) + + + + trainer = SFTTrainer( + model = model, + # tokenizer = tokenizer, + train_dataset = dataset, + # dataset_text_field = "text", + # max_length = max_seq_length, + # dataset_num_proc = 2, + # packing = False, + peft_config=lora_config, + args = TrainingArguments( + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + warmup_steps = 5, + # num_train_epochs = 1, # Set this for 1 full training run. 
+ max_steps = 60, + learning_rate = 2e-4, + fp16 = not SUPPORTS_BFLOAT16, + bf16 = SUPPORTS_BFLOAT16, + logging_steps = 1, + optim = "adamw_8bit", + weight_decay = 0.01, + lr_scheduler_type = "linear", + seed = 3407, + output_dir = "outputs", + report_to = "none", # Use this for WandB etc + ), + ) + + +if __name__ == "__main__": + # parse args + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="unsloth") + parser.add_argument("--profiling", type=str, default="none") + args = parser.parse_args() + + # use args.model to determine which model to train + if args.model == "unsloth": + train_unsloth(args) + elif args.model == "native": + train_native() diff --git a/end2end/unsloth/llama3_1_4b/nsys_timing.sh b/end2end/unsloth/llama3_1_4b/nsys_timing.sh new file mode 100755 index 0000000..b1b19fc --- /dev/null +++ b/end2end/unsloth/llama3_1_4b/nsys_timing.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Tell bash’s built-in time to output only the real time in seconds +export TIMEFORMAT='%R' +ALL_ELAPSED=() +echo "Running 5 trials; printing elapsed (real) time for each:" +for i in {1..5}; do + # Run the profiling command, suppress its stdout, capture the time output + elapsed=$( { time nsys profile --trace=cuda --sample=none --cpuctxsw=none python ./llama3_1.py > /dev/null; } 2>&1 ) + # elapsed=$( { time python ./llama3_1.py > /dev/null; } 2>&1 ) + echo "Trial $i: ${elapsed}s" + ALL_ELAPSED+=($elapsed) +done + +# Print all elapsed times +echo "All elapsed times: ${ALL_ELAPSED[@]}" + +# Calculate and print the average elapsed time + diff --git a/end2end/unsloth/llama3_1_4b/proton_timing.sh b/end2end/unsloth/llama3_1_4b/proton_timing.sh new file mode 100755 index 0000000..139b91c --- /dev/null +++ b/end2end/unsloth/llama3_1_4b/proton_timing.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Tell bash’s built-in time to output only the real time in seconds +export TIMEFORMAT='%R' +ALL_ELAPSED=() +echo "Running 5 trials; printing elapsed (real) time for each:" +for i in {1..5}; do + # Run the profiling command, suppress its stdout, capture the time output + elapsed=$( { time proton ./llama3_1.py > /dev/null; } 2>&1 ) + echo "Trial $i: ${elapsed}s" + ALL_ELAPSED+=($elapsed) +done + +# Print all elapsed times +echo "All elapsed times: ${ALL_ELAPSED[@]}" + +# Calculate and print the average elapsed time + diff --git a/end2end/unsloth/llama3_1_4b/pt_timing.sh b/end2end/unsloth/llama3_1_4b/pt_timing.sh new file mode 100755 index 0000000..a6fd6fd --- /dev/null +++ b/end2end/unsloth/llama3_1_4b/pt_timing.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Tell bash’s built-in time to output only the real time in seconds +export TIMEFORMAT='%R' +ALL_ELAPSED=() +echo "Running 5 trials; printing elapsed (real) time for each:" +for i in {1..5}; do + # Run the profiling command, suppress its stdout, capture the time output + elapsed=$( { time python ./llama3_1.py --profiling torch > /dev/null; } 2>&1 ) + echo "Trial $i: ${elapsed}" + ALL_ELAPSED+=($elapsed) +done + +# Print all elapsed times +echo "All elapsed times: ${ALL_ELAPSED[@]}" + +# Calculate and print the average elapsed time + diff --git a/end2end/unsloth/llama3_1_4b/requirements.txt b/end2end/unsloth/llama3_1_4b/requirements.txt new file mode 100644 index 0000000..996eb24 --- /dev/null +++ b/end2end/unsloth/llama3_1_4b/requirements.txt @@ -0,0 +1,2 @@ +unsloth +peft \ No newline at end of file diff --git a/end2end/unsloth/memsizes.png b/end2end/unsloth/memsizes.png new file 
mode 100644 index 0000000..64e7e11 Binary files /dev/null and b/end2end/unsloth/memsizes.png differ diff --git a/end2end/unsloth/out.txt b/end2end/unsloth/out.txt new file mode 100644 index 0000000..917be60 --- /dev/null +++ b/end2end/unsloth/out.txt @@ -0,0 +1,2676 @@ +-------------------------------------------- +Timing phi-4 with NONE + +>>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for phi-4... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for phi-4... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for phi-4... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for phi-4... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. 
+ Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for phi-4... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for phi-4... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for phi-4... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for phi-4... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for phi-4... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. 
+O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for phi-4... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for phi-4... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for phi-4... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for gemma-3-4b... +==((====))== Unsloth 2025.5.3: Fast Gemma3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. 
+ Loading checkpoint shards: 0%| | 0/2 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for gemma-3-4b... +==((====))== Unsloth 2025.5.3: Fast Gemma3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/2 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for gemma-3-4b... +==((====))== Unsloth 2025.5.3: Fast Gemma3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/2 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for gemma-3-4b... +==((====))== Unsloth 2025.5.3: Fast Gemma3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/2 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for gemma-3-4b... +==((====))== Unsloth 2025.5.3: Fast Gemma3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/2 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for gemma-3-4b... +==((====))== Unsloth 2025.5.3: Fast Gemma3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. 
Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/2 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for gemma-3-4b... +==((====))== Unsloth 2025.5.3: Fast Gemma3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/2 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for gemma-3-4b... +==((====))== Unsloth 2025.5.3: Fast Gemma3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/2 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for gemma-3-4b... +==((====))== Unsloth 2025.5.3: Fast Gemma3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/2 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for gemma-3-4b... +==((====))== Unsloth 2025.5.3: Fast Gemma3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. 
+ Loading checkpoint shards: 0%| | 0/2 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for gemma-3-4b... +==((====))== Unsloth 2025.5.3: Fast Gemma3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/2 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for gemma-3-4b... +==((====))== Unsloth 2025.5.3: Fast Gemma3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/2 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.1-8b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/4 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.1-8b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/4 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.1-8b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. 
Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/4 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.1-8b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/4 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.1-8b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/4 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.1-8b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/4 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.1-8b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. 
+ Loading checkpoint shards: 0%| | 0/4 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.1-8b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/4 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.1-8b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/4 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.1-8b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/4 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.1-8b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/4 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.1-8b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. 
Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/4 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.2-3b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. +Adding LoRA adapters for llama-3.2-3b... +Unsloth: Making `model.base_model.model.model` require gradients +Loading and preparing dataset... +Initializing trainer... + Unsloth: Tokenizing ["text"] (num_proc=72): 0%| | 0/1000 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.2-3b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. +Adding LoRA adapters for llama-3.2-3b... +Unsloth: Making `model.base_model.model.model` require gradients +Loading and preparing dataset... +Initializing trainer... + Unsloth: Tokenizing ["text"] (num_proc=72): 0%| | 0/1000 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.2-3b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. +Adding LoRA adapters for llama-3.2-3b... +Unsloth: Making `model.base_model.model.model` require gradients +Loading and preparing dataset... +Initializing trainer... + Unsloth: Tokenizing ["text"] (num_proc=72): 0%| | 0/1000 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! 
+Initializing model and tokenizer for llama-3.2-3b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. +Adding LoRA adapters for llama-3.2-3b... +Unsloth: Making `model.base_model.model.model` require gradients +Loading and preparing dataset... +Initializing trainer... + Unsloth: Tokenizing ["text"] (num_proc=72): 0%| | 0/1000 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.2-3b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. +Adding LoRA adapters for llama-3.2-3b... +Unsloth: Making `model.base_model.model.model` require gradients +Loading and preparing dataset... +Initializing trainer... + Unsloth: Tokenizing ["text"] (num_proc=72): 0%| | 0/1000 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.2-3b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. +Adding LoRA adapters for llama-3.2-3b... +Unsloth: Making `model.base_model.model.model` require gradients +Loading and preparing dataset... +Initializing trainer... + Unsloth: Tokenizing ["text"] (num_proc=72): 0%| | 0/1000 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.2-3b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. 
+Adding LoRA adapters for llama-3.2-3b... +Unsloth: Making `model.base_model.model.model` require gradients +Loading and preparing dataset... +Initializing trainer... + Unsloth: Tokenizing ["text"] (num_proc=72): 0%| | 0/1000 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.2-3b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. +Adding LoRA adapters for llama-3.2-3b... +Unsloth: Making `model.base_model.model.model` require gradients +Loading and preparing dataset... +Initializing trainer... + Unsloth: Tokenizing ["text"] (num_proc=72): 0%| | 0/1000 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.2-3b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. +Adding LoRA adapters for llama-3.2-3b... +Unsloth: Making `model.base_model.model.model` require gradients +Loading and preparing dataset... +Initializing trainer... + Unsloth: Tokenizing ["text"] (num_proc=72): 0%| | 0/1000 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.2-3b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. +Adding LoRA adapters for llama-3.2-3b... +Unsloth: Making `model.base_model.model.model` require gradients +Loading and preparing dataset... +Initializing trainer... + Unsloth: Tokenizing ["text"] (num_proc=72): 0%| | 0/1000 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.2-3b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. 
+O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. +Adding LoRA adapters for llama-3.2-3b... +Unsloth: Making `model.base_model.model.model` require gradients +Loading and preparing dataset... +Initializing trainer... + Unsloth: Tokenizing ["text"] (num_proc=72): 0%| | 0/1000 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for llama-3.2-3b... +==((====))== Unsloth 2025.5.3: Fast Llama patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. +Adding LoRA adapters for llama-3.2-3b... +Unsloth: Making `model.base_model.model.model` require gradients +Loading and preparing dataset... +Initializing trainer... + Unsloth: Tokenizing ["text"] (num_proc=72): 0%| | 0/1000 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for qwen3-14b... +==((====))== Unsloth 2025.5.3: Fast Qwen3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for qwen3-14b... +==((====))== Unsloth 2025.5.3: Fast Qwen3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for qwen3-14b... +==((====))== Unsloth 2025.5.3: Fast Qwen3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. 
CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for qwen3-14b... +==((====))== Unsloth 2025.5.3: Fast Qwen3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for qwen3-14b... +==((====))== Unsloth 2025.5.3: Fast Qwen3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for qwen3-14b... +==((====))== Unsloth 2025.5.3: Fast Qwen3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for qwen3-14b... +==((====))== Unsloth 2025.5.3: Fast Qwen3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. 
+ Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for qwen3-14b... +==((====))== Unsloth 2025.5.3: Fast Qwen3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for qwen3-14b... +==((====))== Unsloth 2025.5.3: Fast Qwen3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for qwen3-14b... +==((====))== Unsloth 2025.5.3: Fast Qwen3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for qwen3-14b... +==((====))== Unsloth 2025.5.3: Fast Qwen3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for qwen3-14b... +==((====))== Unsloth 2025.5.3: Fast Qwen3 patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. 
+O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/6 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for mistral-7b... +==((====))== Unsloth 2025.5.3: Fast Mistral patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/3 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for mistral-7b... +==((====))== Unsloth 2025.5.3: Fast Mistral patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/3 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for mistral-7b... +==((====))== Unsloth 2025.5.3: Fast Mistral patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/3 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for mistral-7b... +==((====))== Unsloth 2025.5.3: Fast Mistral patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. 
+ Loading checkpoint shards: 0%| | 0/3 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for mistral-7b... +==((====))== Unsloth 2025.5.3: Fast Mistral patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/3 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for mistral-7b... +==((====))== Unsloth 2025.5.3: Fast Mistral patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/3 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for mistral-7b... +==((====))== Unsloth 2025.5.3: Fast Mistral patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/3 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for mistral-7b... +==((====))== Unsloth 2025.5.3: Fast Mistral patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/3 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for mistral-7b... +==((====))== Unsloth 2025.5.3: Fast Mistral patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. 
Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/3 [00:00>> Run #1 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for mistral-7b... +==((====))== Unsloth 2025.5.3: Fast Mistral patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/3 [00:00>> Run #2 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for mistral-7b... +==((====))== Unsloth 2025.5.3: Fast Mistral patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. + Loading checkpoint shards: 0%| | 0/3 [00:00>> Run #3 of 3 +🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. +🦥 Unsloth Zoo will now patch everything to make training faster! +Initializing model and tokenizer for mistral-7b... +==((====))== Unsloth 2025.5.3: Fast Mistral patching. Transformers: 4.51.3. + \\ /| NVIDIA GH200 480GB. Num GPUs = 1. Max memory: 95.0 GB. Platform: Linux. +O^O/ \_/ \ Torch: 2.6.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.2.0 +\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+1298453.d20250514. FA2 = False] + "-____-" Free license: http://github.com/unslothai/unsloth +Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored! +Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA. 
+ Loading checkpoint shards: 0%| | 0/3 [00:00 PROTON + unsloth_trace_*.json -> TORCH + .nsys.nsys-rep -> NSYS + """ + patterns = [ + ('PROTON', re.compile(r'^(?P<model>[\w\.\-]+)\.hatchet$')), + ('TORCH', re.compile(r'^unsloth_trace_(?P<model>[\w\.\-]+)\.json$')), + ('NSYS', re.compile(r'^(?P<model>[\w\.\-]+)\.nsys\.nsys-rep$')) + ] + records = [] + + with open(listing_path, 'r', encoding='utf-8') as f: + for line in f: + parts = line.split() + if len(parts) < 9: + continue + size = int(parts[4]) + filename = parts[8] + for profiler, pat in patterns: + m = pat.match(filename) + if m: + model = m.group('model') + records.append({'model': model, 'profiler_type': profiler, 'size_bytes': size}) + break + return records + + +def write_csv(records, out_path): + """Write parsed size records to CSV.""" + with open(out_path, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=['model', 'profiler_type', 'size_bytes']) + writer.writeheader() + for rec in records: + writer.writerow(rec) + + +def plot_sizes(records): + """Plot grouped bar chart of percent size difference vs PROTON baseline.""" + # Unique models + models = sorted({r['model'] for r in records}) + # Baseline sizes (PROTON) + baseline = {m: next((r['size_bytes'] for r in records if r['model']==m and r['profiler_type']=='PROTON'), 0) + for m in models} + # Ordered profilers excluding baseline + ordered = ['PROTON', 'NSYS', 'TORCH'] + profilers = [p for p in ordered if p!='PROTON' and any(r['profiler_type']==p for r in records)] + + # Compute percent differences + pct = {p: [] for p in profilers} + for m in models: + base = baseline[m] + for p in profilers: + size = next((r['size_bytes'] for r in records if r['model']==m and r['profiler_type']==p), 0) + pct[p].append(((size - base) / base) if base > 0 else 0) + + # Plot + x = np.arange(len(models)) + width = 0.8 / len(profilers) + fig, ax = plt.subplots() + for i, p in enumerate(profilers): + ax.bar(x + i*width, pct[p], width, label=p) + ax.set_xticks(x + width*(len(profilers)-1)/2) + ax.set_xticklabels(models, rotation=45, ha='right') + ax.set_ylabel('Relative Profile Size Difference in Comparison to Proton') + ax.set_title('Profile File Size Comparison by Model') + ax.legend() + plt.tight_layout() + plt.savefig("memsizes.png") + plt.show() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Extract and compare profile file sizes.') + parser.add_argument('listing_file', help='Path to ls -l listing file') + parser.add_argument('out_csv', help='Output CSV path') + parser.add_argument('--plot', action='store_true', help='Show grouped percent-difference bar chart') + args = parser.parse_args() + + records = parse_sizes(args.listing_file) + write_csv(records, args.out_csv) + print(f"Extracted {len(records)} profile size records to {args.out_csv}") + if args.plot: + plot_sizes(records) + diff --git a/end2end/unsloth/phi-4/training.py b/end2end/unsloth/phi-4/training.py new file mode 100644 index 0000000..4e70752 --- /dev/null +++ b/end2end/unsloth/phi-4/training.py @@ -0,0 +1,101 @@ +import os +import torch +from unsloth import FastModel +from datasets import load_dataset +from unsloth.chat_templates import get_chat_template, standardize_data_formats, train_on_responses_only +from trl import SFTTrainer, SFTConfig +import triton.profiler as proton + +def main(profiling_mode): + # Initialize model and tokenizer + print("Initializing model and tokenizer...") + model, tokenizer = FastModel.from_pretrained( + model_name = "unsloth/Phi-4",
max_seq_length = 2048, + load_in_4bit = False, + load_in_8bit = False, + full_finetuning = False, + cache_dir = "/scratch/jlee436/unsloth/model" + ) + + # Add LoRA adapters + print("Adding LoRA adapters...") + model = FastModel.get_peft_model( + model, + finetune_vision_layers = False, + finetune_language_layers = True, + finetune_attention_modules = True, + finetune_mlp_modules = True, + r = 8, + lora_alpha = 8, + lora_dropout = 0, + bias = "none", + random_state = 3407, + ) + + # Load and prepare dataset + print("Loading and preparing dataset...") + dataset = load_dataset("mlabonne/FineTome-100k", split="train", cache_dir="/scratch/jlee436/unsloth/data") + dataset = standardize_data_formats(dataset) + + # Apply chat template + print("Applying chat template...") + tokenizer = get_chat_template( + tokenizer, + chat_template = "gemma-3", + ) + + def apply_chat_template(examples): + texts = tokenizer.apply_chat_template(examples["conversations"]) + return {"text": texts} + + dataset = dataset.map(apply_chat_template, batched=True) + + # Initialize trainer + print("Initializing trainer...") + trainer = SFTTrainer( + model = model, + tokenizer = tokenizer, + train_dataset = dataset, + eval_dataset = None, + args = SFTConfig( + dataset_text_field = "text", + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + warmup_steps = 5, + max_steps = 30, # Change this for longer training + learning_rate = 2e-4, + logging_steps = 1, + optim = "adamw_8bit", + weight_decay = 0.01, + lr_scheduler_type = "linear", + seed = 3407, + report_to = "none", + ), + ) + trainer = train_on_responses_only( + trainer, + instruction_part = "<start_of_turn>user\n", + response_part = "<start_of_turn>model\n", + ) + # Train on responses only + print("Setting up response-only training...") + if profiling_mode == "proton": + session_id = proton.start(name="profile_name", context="shadow") + trainer_stats = trainer.train() + proton.finalize(session_id) + elif profiling_mode == "torch": + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + trainer_stats = trainer.train() + else: + trainer_stats = trainer.train() + + +if __name__ == "__main__": + # parse arguments to check profiling mode, which is a string "proton", "torch", or "none" + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiling", type=str, default="none") + args = parser.parse_args() + + main(args.profiling) \ No newline at end of file diff --git a/end2end/unsloth/profile-wrapper.sh b/end2end/unsloth/profile-wrapper.sh new file mode 100755 index 0000000..213a5aa --- /dev/null +++ b/end2end/unsloth/profile-wrapper.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash +set -euo pipefail + +if (( $# < 2 )); then + echo "Usage: $0 <num_runs> <command> [args...]" + echo "Example: $0 5 nsys python myscript.py arg1" + exit 1 +fi + +# Number of times to run +num_runs=$1 +shift + +# Array to hold elapsed times +declare -a times + +for ((i=1; i<=num_runs; i++)); do + echo + echo ">>> Run #$i of $num_runs" + sleep 5 + + # capture output to a temp log + logfile=$(mktemp /tmp/profile-log.XXXXXX) + + # run the wrapped command + start_marker="" + # start time comes from script’s own START_PROFILE line; end time we record + "$@" 2>&1 | tee "$logfile" + exit_code=${PIPESTATUS[0]} + + # get end timestamp + end_ts=$(date +%s.%N) + + # extract the start timestamp from the first START_PROFILE line + start_line=$(grep -m1 '^START_PROFILE: ' "$logfile" || true) + if [[ -z "$start_line" ]]; then + echo "ERROR: no START_PROFILE found in run #$i" >&2 + rm -f "$logfile" + exit
1 + fi + start_ts=${start_line#START_PROFILE:\ } + + # compute elapsed + elapsed=$(echo "$end_ts - $start_ts" | bc) + + # store + times+=("$elapsed") + + # report this run + printf "Run %2d: %s seconds\n" "$i" "$elapsed" + + rm -f "$logfile" + + # if the wrapped command failed, stop early + if [[ $exit_code -ne 0 ]]; then + echo "Wrapped command exited with code $exit_code. Aborting." + exit $exit_code + fi +done + +# Summary: compute min, max, avg via bc +min=${times[0]} +max=${times[0]} +sum=0 +for t in "${times[@]}"; do + # compare floats: use bc + sleep 5 + is_less=$(echo "$t < $min" | bc) + (( is_less )) && min=$t + is_greater=$(echo "$t > $max" | bc) + (( is_greater )) && max=$t + sum=$(echo "$sum + $t" | bc) +done + +avg=$(echo "$sum / $num_runs" | bc -l) + +echo +echo "=== SUMMARY over $num_runs runs ===" +echo " Min elapsed : $min seconds" +echo " Max elapsed : $max seconds" +echo " Avg elapsed : $avg seconds" +echo "===================================" + +exit 0 + diff --git a/end2end/unsloth/qwen3/training.py b/end2end/unsloth/qwen3/training.py new file mode 100644 index 0000000..debda69 --- /dev/null +++ b/end2end/unsloth/qwen3/training.py @@ -0,0 +1,101 @@ +import os +import torch +from unsloth import FastModel +from datasets import load_dataset +from unsloth.chat_templates import get_chat_template, standardize_data_formats, train_on_responses_only +from trl import SFTTrainer, SFTConfig +import triton.profiler as proton + +def main(profiling_mode): + # Initialize model and tokenizer + print("Initializing model and tokenizer...") + model, tokenizer = FastModel.from_pretrained( + model_name = "unsloth/Qwen3-14B", + max_seq_length = 2048, + load_in_4bit = False, + load_in_8bit = False, + full_finetuning = False, + cache_dir = "/scratch/jlee436/unsloth/model" + ) + + # Add LoRA adapters + print("Adding LoRA adapters...") + model = FastModel.get_peft_model( + model, + finetune_vision_layers = False, + finetune_language_layers = True, + finetune_attention_modules = True, + finetune_mlp_modules = True, + r = 8, + lora_alpha = 8, + lora_dropout = 0, + bias = "none", + random_state = 3407, + ) + + # Load and prepare dataset + print("Loading and preparing dataset...") + dataset = load_dataset("mlabonne/FineTome-100k", split="train", cache_dir="/scratch/jlee436/unsloth/data") + dataset = standardize_data_formats(dataset) + + # Apply chat template + print("Applying chat template...") + tokenizer = get_chat_template( + tokenizer, + chat_template = "gemma-3", + ) + + def apply_chat_template(examples): + texts = tokenizer.apply_chat_template(examples["conversations"]) + return {"text": texts} + + dataset = dataset.map(apply_chat_template, batched=True) + + # Initialize trainer + print("Initializing trainer...") + trainer = SFTTrainer( + model = model, + tokenizer = tokenizer, + train_dataset = dataset, + eval_dataset = None, + args = SFTConfig( + dataset_text_field = "text", + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + warmup_steps = 5, + max_steps = 30, # Change this for longer training + learning_rate = 2e-4, + logging_steps = 1, + optim = "adamw_8bit", + weight_decay = 0.01, + lr_scheduler_type = "linear", + seed = 3407, + report_to = "none", + ), + ) + trainer = train_on_responses_only( + trainer, + instruction_part = "user\n", + response_part = "model\n", + ) + # Train on responses only + print("Setting up response-only training...") + if profiling_mode == "proton": + session_id = proton.start(name="profile_name", context="shadow") + trainer_stats = 
trainer.train() + proton.finalize(session_id) + elif profiling_mode == "torch": + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + trainer_stats = trainer.train() + else: + trainer_stats = trainer.train() + + +if __name__ == "__main__": + # parse arguments to check profiling mode, which is a string "proton", "torch", or "none" + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiling", type=str, default="none") + args = parser.parse_args() + + main(args.profiling) \ No newline at end of file diff --git a/end2end/unsloth/sizes.csv b/end2end/unsloth/sizes.csv new file mode 100644 index 0000000..e5f97e0 --- /dev/null +++ b/end2end/unsloth/sizes.csv @@ -0,0 +1,19 @@ +model,profiler_type,size_bytes +mistral-7b,TORCH,1114198954 +mistral-7b,PROTON,171632 +qwen3-14b,TORCH,1420212173 +qwen3-14b,PROTON,190765 +llama-3.2-3b,TORCH,951312658 +llama-3.2-3b,PROTON,174033 +llama-3.1-8b,TORCH,1105145854 +llama-3.1-8b,PROTON,164490 +gemma-3-4b,TORCH,1348787048 +gemma-3-4b,PROTON,271858 +phi-4,TORCH,1389554811 +phi-4,PROTON,196665 +mistral-7b,NSYS,59051527 +qwen3-14b,NSYS,76175100 +llama-3.2-3b,NSYS,50375612 +llama-3.1-8b,NSYS,58898031 +gemma-3-4b,NSYS,70990128 +phi-4,NSYS,74228119 diff --git a/end2end/unsloth/training.py b/end2end/unsloth/training.py new file mode 100644 index 0000000..2714145 --- /dev/null +++ b/end2end/unsloth/training.py @@ -0,0 +1,152 @@ +import os +import time +import torch +from unsloth import FastModel +from datasets import load_dataset +from unsloth.chat_templates import get_chat_template, standardize_data_formats, train_on_responses_only +from unsloth.chat_templates import standardize_sharegpt +import pandas as pd +from trl import SFTTrainer, SFTConfig +from datasets import Dataset +import triton.profiler as proton + +model_map = { + "phi-4": "unsloth/Phi-4", + "gemma-3-4b": "unsloth/gemma-3-4b-it", + "llama-3.1-8b": "unsloth/Llama-3.1-8B", + "llama-3.2-3b": "unsloth/Llama-3.2-3B", + "qwen3-14b": "unsloth/Qwen3-14B", + "mistral-7b": "unsloth/mistral-7b-instruct-v0.3-bnb-4bit", + "llama-4-scout": "unsloth/Llama-4-Scout-17B-16E-unsloth-bnb-4bit" +} +def format_dataset_llama(dataset, tokenizer): + + alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + + ### Instruction: + {} + + ### Input: + {} + + ### Response: + {}""" + + EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN + def formatting_prompts_func(examples): + instructions = examples["instruction"] + inputs = examples["input"] + outputs = examples["output"] + texts = [] + for instruction, input, output in zip(instructions, inputs, outputs): + # Must add EOS_TOKEN, otherwise your generation will go on forever! 
+ text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN + texts.append(text) + return { "text" : texts, } + pass + + return dataset.map(formatting_prompts_func, batched = True,) + +def main(profiling_mode, model_name): + # Initialize model and tokenizer + print(f"Initializing model and tokenizer for {model_name}...") + model, tokenizer = FastModel.from_pretrained( + model_name = model_map[model_name], + max_seq_length = 2048, + load_in_4bit = False, + load_in_8bit = False, + full_finetuning = False, + cache_dir = "/scratch/jlee436/unsloth/model" + ) + + # Add LoRA adapters + print(f"Adding LoRA adapters for {model_name}...") + model = FastModel.get_peft_model( + model, + finetune_vision_layers = False, + finetune_language_layers = True, + finetune_attention_modules = True, + finetune_mlp_modules = True, + r = 8, + lora_alpha = 8, + lora_dropout = 0, + bias = "none", + random_state = 3407, + ) + + # Load and prepare dataset + print("Loading and preparing dataset...") + + + + dataset = load_dataset("mlabonne/FineTome-100k", split="train", cache_dir="/scratch/jlee436/unsloth/data") + dataset = standardize_sharegpt(dataset, tokenizer) + + if "llama" in model_name.lower(): + tokenizer = get_chat_template( + tokenizer, + chat_template = "llama-3.1", + ) + elif "mistral" in model_name.lower(): + tokenizer = get_chat_template( + tokenizer, + chat_template = "mistral", + ) + + + dataset = tokenizer.apply_chat_template( + dataset["conversations"], + tokenize = False, + ) + + data = pd.Series(dataset)[:1000] + data.name = "text" + final_dataset = Dataset.from_pandas(pd.DataFrame(data)) + + # Initialize trainer + print("Initializing trainer...") + trainer = SFTTrainer( + model = model, + tokenizer = tokenizer, + train_dataset = final_dataset, + eval_dataset = None, + args = SFTConfig( + dataset_text_field = "text", + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + warmup_steps = 5, + max_steps = 30, # Change this for longer training + learning_rate = 2e-4, + logging_steps = 15, + optim = "adamw_8bit", + weight_decay = 0.01, + lr_scheduler_type = "linear", + seed = 3407, + report_to = "none", + ), + ) + + print(f"START_PROFILE: {time.time()}") + if profiling_mode == "proton": + session_id = proton.start(name=f"unsloth_{model_name}", context="shadow") + trainer.train() + proton.finalize(session_id) + elif profiling_mode == "torch": + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + trainer.train() + prof.export_chrome_trace(f"unsloth_trace_{model_name}.json") + else: + trainer.train() + print(f"END_PROFILE: {time.time()}") + + + +if __name__ == "__main__": + # parse arguments to check profiling mode, which is a string "proton", "torch", or "none" + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiling", type=str, default="none") + parser.add_argument("--model", type=str, default="phi-4") + args = parser.parse_args() + + main(args.profiling, args.model) diff --git a/kernel-microbench/01-vector-add.py b/kernel-microbench/01-vector-add.py new file mode 100644 index 0000000..258a891 --- /dev/null +++ b/kernel-microbench/01-vector-add.py @@ -0,0 +1,111 @@ +""" +Vector Addition +=============== + +In this tutorial, you will write a simple vector addition using Triton. + +In doing so, you will learn about: + +* The basic programming model of Triton. + +* The `triton.jit` decorator, which is used to define Triton kernels. 
+ +* The best practices for validating and benchmarking your custom ops against native reference implementations. + +""" + +import argparse +import torch + +import triton +import triton.language as tl +import triton.profiler as proton + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +def add(x: torch.Tensor, y: torch.Tensor): + # We need to preallocate the output. + output = torch.empty_like(x) + assert x.is_cuda and y.is_cuda and output.is_cuda + n_elements = output.numel() + # The SPMD launch grid denotes the number of kernel instances that run in parallel. + # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. + # In this case, we use a 1D grid where the size is the number of blocks: + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + # NOTE: + # - Each torch.tensor object is implicitly converted into a pointer to its first element. + # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. + # - Don't forget to pass meta-parameters as keywords arguments. + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still + # running asynchronously at this point. 
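Because the launch above returns before the kernel has finished, any host-side wall-clock measurement around `add` has to synchronize the device first, or it only captures launch overhead. A minimal sketch of such a measurement (illustrative only; it assumes `x` and `y` are CUDA tensors created as in `benchmark()` below):

    import time
    torch.cuda.synchronize()                 # drain any previously queued work
    t0 = time.perf_counter()
    add(x, y)
    torch.cuda.synchronize()                 # wait for the add kernel itself to finish
    print(f"add() took {time.perf_counter() - t0:.6f} s")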
+ return output + + +def benchmark(): + sizes = [2**i for i in range(12, 28, 1)] + for size in sizes: + with proton.scope(f"size_{size}"): + for _ in range(500): + x = torch.rand(size, device='cuda', dtype=torch.float32) + y = torch.rand(size, device='cuda', dtype=torch.float32) + with proton.scope("triton"): + add(x, y) + with proton.scope("torch"): + torch.add(x, y) + + +def main(): + args = argparse.ArgumentParser() + # just check if flag is provided + args.add_argument("--profiler", type=str, default="") + args.add_argument("--pc_sampling", action="store_true") + args = args.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark"): + benchmark() + + with open("vector_add_torch.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + backend = "cupti_pcsampling" if args.pc_sampling else "cupti" + proton.start(name="proton_add", context="shadow", backend=backend) + benchmark() + proton.finalize() + else: + print("Profiling with nsys") + benchmark() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/kernel-microbench/01-vector-add_results.csv b/kernel-microbench/01-vector-add_results.csv new file mode 100644 index 0000000..6942a72 --- /dev/null +++ b/kernel-microbench/01-vector-add_results.csv @@ -0,0 +1,6 @@ +run,script,baseline,proton,torch,nsys +1,01-vector-add.py,7.68,8.15,7.86,8.78 +2,01-vector-add.py,7.44,8.01,8.24,8.35 +3,01-vector-add.py,7.32,7.75,7.83,8.33 +4,01-vector-add.py,7.60,7.86,8.05,8.36 +5,01-vector-add.py,7.53,8.18,7.59,8.14 diff --git a/kernel-microbench/02-fused-softmax_results.csv b/kernel-microbench/02-fused-softmax_results.csv new file mode 100644 index 0000000..4f570f3 --- /dev/null +++ b/kernel-microbench/02-fused-softmax_results.csv @@ -0,0 +1,6 @@ +run,script,baseline,proton,torch,nsys +1,02-fused-softmax.py,9.39,9.69,9.78,10.17 +2,02-fused-softmax.py,9.26,9.64,10.15,10.40 +3,02-fused-softmax.py,9.81,10.13,10.05,10.42 +4,02-fused-softmax.py,9.32,9.66,10.07,10.24 +5,02-fused-softmax.py,9.26,9.67,9.88,10.44 diff --git a/kernel-microbench/02-softmax.py b/kernel-microbench/02-softmax.py new file mode 100644 index 0000000..a602021 --- /dev/null +++ b/kernel-microbench/02-softmax.py @@ -0,0 +1,229 @@ +""" +Fused Softmax +============= + +In this tutorial, you will write a fused softmax operation that is significantly faster +than PyTorch's native op for a particular class of matrices: those whose rows can fit in +the GPU's SRAM. + +In doing so, you will learn about: + +* The benefits of kernel fusion for bandwidth-bound operations. + +* Reduction operators in Triton. + +""" + +# %% +# Motivations +# ----------- +# +# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. 
+# Let us consider instead the case of a simple (numerically stabilized) softmax operation: + +import torch + +import triton +import triton.language as tl +from triton.runtime import driver +import triton.profiler as proton + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def is_cdna(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942', + 'gfx90a', 'gfx908') + + +def naive_softmax(x): + """Compute row-wise softmax of X using native pytorch + + We subtract the maximum element in order to avoid overflows. Softmax is invariant to + this shift. + """ + # read MN elements ; write M elements + x_max = x.max(dim=1)[0] + # read MN + M elements ; write MN elements + z = x - x_max[:, None] + # read MN elements ; write MN elements + numerator = torch.exp(z) + # read MN elements ; write M elements + denominator = numerator.sum(dim=1) + # read MN + M elements ; write MN elements + ret = numerator / denominator[:, None] + # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements + return ret + + +# %% +# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` +# requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements. +# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads +# X once and does all the necessary computations on-chip. +# Doing so would require reading and writing back only :math:`MN` bytes, so we could +# expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`). +# The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically +# but, as we will see later, it is still far from ideal. + +# %% +# Compute Kernel +# -------------- +# +# Our softmax kernel works as follows: each program loads a set of rows of the input matrix X strided by number of programs, +# normalizes it and writes back the result to the output Y. 
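To make the ~4x estimate above concrete, here is a quick back-of-the-envelope check using the element counts quoted in the comments (illustrative only; M and N are arbitrary example sizes):

    M, N = 4096, 1024
    naive_elems = 8 * M * N + 4 * M    # (5MN + 2M) reads + (3MN + 2M) writes
    fused_elems = 2 * M * N            # read X once, write Y once
    print(naive_elems / fused_elems)   # ~4.002, i.e. roughly the promised 4x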
+# +# Note that one important limitation of Triton is that each block must have a +# power-of-two number of elements, so we need to internally "pad" each row and guard the +# memory operations properly if we want to handle any possible input shapes: + + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, + num_stages: tl.constexpr): + # starting row of the program + row_start = tl.program_id(0) + row_step = tl.num_programs(0) + for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages): + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + mask = col_offsets < n_cols + row = tl.load(input_ptrs, mask=mask, other=-float('inf')) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=mask) + + +# %% +# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor. + +device = torch.cuda.current_device() +properties = driver.active.utils.get_device_properties(device) +NUM_SM = properties["multiprocessor_count"] +NUM_REGS = properties["max_num_regs"] +SIZE_SMEM = properties["max_shared_mem"] +WARP_SIZE = properties["warpSize"] +target = triton.runtime.driver.active.get_current_target() +kernels = {} + + +def softmax(x): + n_rows, n_cols = x.shape + + # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x` + BLOCK_SIZE = triton.next_power_of_2(n_cols) + + # Another trick we can use is to ask the compiler to use more threads per row by + # increasing the number of warps (`num_warps`) over which each row is distributed. + # You will see in the next tutorial how to auto-tune this value in a more natural + # way so you don't have to come up with manual heuristics yourself. + num_warps = 8 + + # Number of software pipelining stages. + num_stages = 4 if SIZE_SMEM > 200000 else 2 + + # Allocate output + y = torch.empty_like(x) + + # pre-compile kernel to get register usage and compute thread occupancy. + kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0)) + if kernel is None: + kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, + num_stages=num_stages, num_warps=num_warps, grid=(1, )) + kernel._init_handles() + n_regs = kernel.n_regs + size_smem = kernel.metadata.shared + if is_hip(): + # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available. + # However, this is not always the case. In most cases all registers can be used as regular purpose registers. + # ISA SECTION (3.6.4 for CDNA3) + # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. 
Accumulation VGPRs are used + # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total + # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is + # not required to be equal numbers of both types. + if is_cdna(): + NUM_GPRS = NUM_REGS * 2 + + # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor. + # When we divide this number with WARP_SIZE we get maximum number of waves that can + # execute on a CU (multi-processor) in parallel. + MAX_NUM_THREADS = properties["max_threads_per_sm"] + max_num_waves = MAX_NUM_THREADS // WARP_SIZE + occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps + else: + occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) + occupancy = min(occupancy, SIZE_SMEM // size_smem) + num_programs = NUM_SM * occupancy + kernels[BLOCK_SIZE] = (kernel, num_programs) + + num_programs = min(num_programs, n_rows) + + # Create a number of persistent programs. + kernel[(num_programs, 1, 1)]( + y, + x, + x.stride(0), + y.stride(0), + n_rows, + n_cols, + ) + return y + + +def benchmark(): + M = 4096 + N_vals = [2**i for i in range(8, 16)] + providers = ['triton', 'torch'] + for N in N_vals: + with proton.scope(f"N_{N}"): + for _ in range(1000): + x = torch.randn(M, N, device='cuda', dtype=torch.float32) + for provider in providers: + with proton.scope(provider): + if provider == 'torch': + # continue + torch.softmax(x, axis=-1) + # pass + elif provider == 'triton': + pass + # softmax(x) + + +def main(): + import argparse + args = argparse.ArgumentParser() + args.add_argument("--profiler", type=str, default="") + args.add_argument("--pc_sampling", action="store_true") + args = args.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark"): + benchmark() + elif args.profiler == "proton": + print("Profiling with proton") + backend = "cupti_pcsampling" if args.pc_sampling else "cupti" + proton.start(name="proton_softmax", context="shadow", backend=backend) + benchmark() + proton.finalize() + else: + print("Profiling with nsys") + benchmark() + + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/03-matrix-multiplication.py b/kernel-microbench/03-matrix-multiplication.py new file mode 100644 index 0000000..e07b29e --- /dev/null +++ b/kernel-microbench/03-matrix-multiplication.py @@ -0,0 +1,403 @@ +""" +Matrix Multiplication +===================== +In this tutorial, you will write a very short high-performance FP16 matrix multiplication kernel that achieves +performance on par with cuBLAS or rocBLAS. + +You will specifically learn about: + +* Block-level matrix multiplications. + +* Multi-dimensional pointer arithmetic. + +* Program re-ordering for improved L2 cache hit rate. + +* Automatic performance tuning. + +""" + +# %% +# Motivations +# ----------- +# +# Matrix multiplications are a key building block of most modern high-performance computing systems. +# They are notoriously hard to optimize, hence their implementation is generally done by +# hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). +# Unfortunately, these libraries are often proprietary and cannot be easily customized +# to accommodate the needs of modern deep learning workloads (e.g., fused activation functions). 
+# In this tutorial, you will learn how to implement efficient matrix multiplications by +# yourself with Triton, in a way that is easy to customize and extend. +# +# Roughly speaking, the kernel that we will write will implement the following blocked +# algorithm to multiply a (M, K) by a (K, N) matrix: +# +# .. code-block:: python +# +# # Do in parallel +# for m in range(0, M, BLOCK_SIZE_M): +# # Do in parallel +# for n in range(0, N, BLOCK_SIZE_N): +# acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32) +# for k in range(0, K, BLOCK_SIZE_K): +# a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] +# b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] +# acc += dot(a, b) +# C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc +# +# where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance. + +# %% +# Compute Kernel +# -------------- +# +# The above algorithm is, actually, fairly straightforward to implement in Triton. +# The main difficulty comes from the computation of the memory locations at which blocks +# of :code:`A` and :code:`B` must be read in the inner loop. For that, we need +# multi-dimensional pointer arithmetic. +# +# Pointer Arithmetic +# ~~~~~~~~~~~~~~~~~~~ +# +# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given +# by :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`. +# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and +# :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as: +# +# .. code-block:: python +# +# &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1); +# &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1); +# +# Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as the following +# code. Also note that we need an extra modulo to handle the case where :code:`M` is not a multiple of +# :code:`BLOCK_SIZE_M` or :code:`N` is not a multiple of :code:`BLOCK_SIZE_N`, in which case we can pad the data with +# some useless values, which will not contribute to the results. For the :code:`K` dimension, we will handle that later +# using masking load semantics. +# +# .. code-block:: python +# +# offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M +# offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N +# offs_k = tl.arange(0, BLOCK_SIZE_K) +# a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak) +# b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn) +# +# And then updated in the inner loop as follows: +# +# .. code-block:: python +# +# a_ptrs += BLOCK_SIZE_K * stride_ak; +# b_ptrs += BLOCK_SIZE_K * stride_bk; +# +# +# L2 Cache Optimizations +# ~~~~~~~~~~~~~~~~~~~~~~ +# +# As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]` +# block of :code:`C`. +# It is important to remember that the order in which these blocks are computed does +# matter, since it affects the L2 cache hit rate of our program, and unfortunately, a +# simple row-major ordering +# +# .. code-block:: Python +# +# pid = tl.program_id(axis=0) +# grid_n = tl.cdiv(N, BLOCK_SIZE_N) +# pid_m = pid // grid_n +# pid_n = pid % grid_n +# +# is just not going to cut it. +# +# One possible solution is to launch blocks in an order that promotes data reuse. 
+# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE_M` rows before +# switching to the next column: +# +# .. code-block:: python +# +# # Program ID +# pid = tl.program_id(axis=0) +# # Number of program ids along the M axis +# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) +# # Number of programs ids along the N axis +# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) +# # Number of programs in group +# num_pid_in_group = GROUP_SIZE_M * num_pid_n +# # Id of the group this program is in +# group_id = pid // num_pid_in_group +# # Row-id of the first program in the group +# first_pid_m = group_id * GROUP_SIZE_M +# # If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller +# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) +# # *Within groups*, programs are ordered in a column-major order +# # Row-id of the program in the *launch grid* +# pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) +# # Col-id of the program in the *launch grid* +# pid_n = (pid % num_pid_in_group) // group_size_m +# +# For example, in the following matmul where each matrix is 9 blocks by 9 blocks, +# we can see that if we compute the output in row-major ordering, we need to load 90 +# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped +# ordering, we only need to load 54 blocks. +# +# .. image:: grouped_vs_row_major_ordering.png +# +# In practice, this can improve the performance of our matrix multiplication kernel by +# more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100). +# + +# %% +# Final Result +# ------------ + +import torch + +import triton +import triton.language as tl +import argparse +import triton.profiler as proton + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip_mi200(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == 'hip' and target.arch == 'gfx90a' + + +def get_cuda_autotune_config(): + return [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + # Good config for fp8 inputs. 
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4) + ] + + +def get_hip_autotune_config(): + return [ + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=4, num_stages=2), + triton.Config( + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, + num_warps=8, num_stages=2), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=8, num_stages=2), + triton.Config( + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, + num_warps=4, num_stages=2), + triton.Config( + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, + num_warps=4, num_stages=2), + ] + + +def get_autotune_config(): + if is_cuda(): + return get_cuda_autotune_config() + else: + return get_hip_autotune_config() + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=get_autotune_config(), + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. 
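    # Worked example of the grouped mapping computed below (illustrative sizes, not part
    # of the original tutorial): with num_pid_m = num_pid_n = 4 and GROUP_SIZE_M = 2, the
    # first group is pid = 0..7 and maps to
    #   pid:    0      1      2      3      4      5      6      7
    #   (m,n): (0,0)  (1,0)  (0,1)  (1,1)  (0,2)  (1,2)  (0,3)  (1,3)
    # i.e. a 2-row band of C is swept column by column before moving to the next band,
    # which keeps the corresponding blocks of A and B resident in L2.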
+ pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. +@triton.jit +def leaky_relu(x): + return tl.where(x >= 0, x, 0.01 * x) + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. 
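    # Sizing note (illustrative, using one of the autotune configs above): for
    # M = N = 512 with BLOCK_SIZE_M = 128 and BLOCK_SIZE_N = 256, the grid below is
    # cdiv(512, 128) * cdiv(512, 256) = 4 * 2 = 8 programs, each writing one
    # 128 x 256 tile of C.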
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ACTIVATION=activation # + ) + return c + + +def simple_benchmark(): + # Use the same configs as in the perf_report for argument ranges + TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2") + ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS' + sizes = [128 * i for i in range(2, 6)] + for fp8_inputs in [False, True]: + if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()): + continue + providers = ["triton"] if fp8_inputs else [ref_lib.lower(), "triton"] + for provider in providers: + for size in sizes: + M = N = K = size + for _ in range(20): # Run a few iterations for each config + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + # if TORCH_HAS_FP8 and fp8_inputs: + # a = a.to(torch.float8_e5m2) + # b = b.T + # b = b.to(torch.float8_e5m2) + if provider == ref_lib.lower(): + torch.matmul(a, b) + elif provider == 'triton': + matmul(a, b) + + +def main(): + args = argparse.ArgumentParser() + args.add_argument("--profiler", type=str, default="") + args.add_argument("--pc_sampling", action="store_true") + args = args.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark"): + simple_benchmark() + with open("matmul_torch.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + + elif args.profiler == "proton": + print("Profiling with proton") + backend = "cupti_pcsampling" if args.pc_sampling else "cupti" + proton.start(name="proton_matmul", context="shadow", backend=backend) + simple_benchmark() + proton.finalize() + else: + print("Profiling with nsys") + simple_benchmark() + + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/03-matrix-multiplication_results.csv b/kernel-microbench/03-matrix-multiplication_results.csv new file mode 100644 index 0000000..77f094d --- /dev/null +++ b/kernel-microbench/03-matrix-multiplication_results.csv @@ -0,0 +1,6 @@ +run,script,baseline,proton,torch,nsys +1,03-matrix-multiplication.py,12.99,12.52,13.24,14.19 +2,03-matrix-multiplication.py,12.79,12.27,13.66,14.76 +3,03-matrix-multiplication.py,12.52,12.36,13.89,14.50 +4,03-matrix-multiplication.py,12.82,12.90,13.54,14.57 +5,03-matrix-multiplication.py,12.95,12.24,13.40,15.12 diff --git a/kernel-microbench/04-low-memory-dropout.py b/kernel-microbench/04-low-memory-dropout.py new file mode 100644 index 0000000..63bbab9 --- /dev/null +++ b/kernel-microbench/04-low-memory-dropout.py @@ -0,0 +1,131 @@ +""" +Low-Memory Dropout +================== + +In this tutorial, you will write a memory-efficient implementation of dropout whose state +will be composed of a single int32 seed. This differs from more traditional implementations of dropout, +whose state is generally composed of a bit mask tensor of the same shape as the input. + +In doing so, you will learn about: + +* The limitations of naive implementations of Dropout with PyTorch. + +* Parallel pseudo-random number generation in Triton. 
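A usage sketch of the single-seed approach described above (illustrative only; `seeded_dropout` is the helper defined further down in this file): re-running with the same seed regenerates exactly the same drop pattern, so no mask tensor ever has to be stored.

    x = torch.randn(8, device='cuda')
    out1 = seeded_dropout(x, p=0.5, seed=123)
    out2 = seeded_dropout(x, p=0.5, seed=123)   # same seed -> identical dropout pattern
    assert torch.equal(out1, out2)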
+ +""" + + +import torch + +import triton +import triton.language as tl +import argparse +import torch.profiler +import triton.profiler as proton + +@triton.jit +def _dropout( + x_ptr, # pointer to the input + x_keep_ptr, # pointer to a mask of 0s and 1s + output_ptr, # pointer to the output + n_elements, # number of elements in the `x` tensor + p, # probability that an element of `x` is changed to zero + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + # Load data + x = tl.load(x_ptr + offsets, mask=mask) + x_keep = tl.load(x_keep_ptr + offsets, mask=mask) + # The line below is the crucial part, described in the paragraph above! + output = tl.where(x_keep, x / (1 - p), 0.0) + # Write-back output + tl.store(output_ptr + offsets, output, mask=mask) + + +def dropout(x, x_keep, p): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) + return output + + + +@triton.jit +def _seeded_dropout( + x_ptr, + output_ptr, + n_elements, + p, + seed, + BLOCK_SIZE: tl.constexpr, +): + # compute memory offsets of elements handled by this instance + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # load data from x + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + # randomly prune it + random = tl.rand(seed, offsets) + x_keep = random > p + # write-back + output = tl.where(x_keep, x / (1 - p), 0.0) + tl.store(output_ptr + offsets, output, mask=mask) + + +def seeded_dropout(x, p, seed): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024) + return output + + +def benchmark(): + sizes = [2**i for i in range(12, 20, 1)] + p = 0.5 + for size in sizes: + with torch.cuda.device(0): + for _ in range(500): + x = torch.randn(size, device='cuda', dtype=torch.float32) + # Baseline dropout + x_keep = (torch.rand(size, device='cuda') > p).to(torch.int32) + dropout(x, x_keep, p) + # Seeded dropout + seed = torch.randint(0, 2**31, ()).item() + seeded_dropout(x, p, seed) + + +def main(): + + args = argparse.ArgumentParser() + args.add_argument("--profiler", type=str, default="") + args.add_argument("--pc_sampling", action="store_true") + args = args.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark"): + benchmark() + with open("dropout_torch.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + backend = "cupti_pcsampling" if args.pc_sampling else "cupti" + proton.start(name="proton_dropout", context="shadow", backend=backend) + benchmark() + proton.finalize() + else: + print("Profiling with nsys") + benchmark() + + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/04-low-memory-dropout_results.csv b/kernel-microbench/04-low-memory-dropout_results.csv new file mode 100644 index 0000000..3b869e6 --- /dev/null +++ b/kernel-microbench/04-low-memory-dropout_results.csv @@ -0,0 
+1,6 @@ +run,script,baseline,proton,torch,nsys +1,04-low-memory-dropout.py,8.21,8.93,7.91,8.54 +2,04-low-memory-dropout.py,8.50,8.09,8.54,8.98 +3,04-low-memory-dropout.py,7.61,8.45,8.26,8.52 +4,04-low-memory-dropout.py,8.07,8.44,8.22,8.32 +5,04-low-memory-dropout.py,7.71,8.23,8.10,8.56 diff --git a/kernel-microbench/05-layer-norm.py b/kernel-microbench/05-layer-norm.py new file mode 100644 index 0000000..e8e9658 --- /dev/null +++ b/kernel-microbench/05-layer-norm.py @@ -0,0 +1,439 @@ +""" +Layer Normalization +==================== +In this tutorial, you will write a high-performance layer normalization +kernel that runs faster than the PyTorch implementation. + +In doing so, you will learn about: + +* Implementing backward pass in Triton. + +* Implementing parallel reduction in Triton. + +""" + +# %% +# Motivations +# ----------- +# +# The *LayerNorm* operator was first introduced in [BA2016]_ as a way to improve the performance +# of sequential models (e.g., Transformers) or neural networks with small batch size. +# It takes a vector :math:`x` as input and produces a vector :math:`y` of the same shape as output. +# The normalization is performed by subtracting the mean and dividing by the standard deviation of :math:`x`. +# After the normalization, a learnable linear transformation with weights :math:`w` and biases :math:`b` is applied. +# The forward pass can be expressed as follows: +# +# .. math:: +# y = \frac{ x - \text{E}[x] }{ \sqrt{\text{Var}(x) + \epsilon} } * w + b +# +# where :math:`\epsilon` is a small constant added to the denominator for numerical stability. +# Let's first take a look at the forward pass implementation. + +import torch + +import triton +import triton.language as tl +import triton.profiler as proton + +try: + # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it + # should not be added to extras_require in setup.py. + import apex + HAS_APEX = True +except ModuleNotFoundError: + HAS_APEX = False + + +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.) 
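    # Note on the tl.where above: the padded lanes (cols >= N) were loaded as 0.0, so
    # without zeroing x - mean they would each add (0 - mean)^2 = mean^2 to _var. Masking
    # them to 0 keeps the accumulated sum exact, and the division by N below uses the
    # true column count.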
+ _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) + + +# %% +# Backward pass +# ------------- +# +# The backward pass for the layer normalization operator is a bit more involved than the forward pass. +# Let :math:`\hat{x}` be the normalized inputs :math:`\frac{ x - \text{E}[x] }{ \sqrt{\text{Var}(x) + \epsilon} }` before the linear transformation, +# the Vector-Jacobian Products (VJP) :math:`\nabla_{x}` of :math:`x` are given by: +# +# .. math:: +# \nabla_{x} = \frac{1}{\sigma}\Big( \nabla_{y} \odot w - \underbrace{ \big( \frac{1}{N} \hat{x} \cdot (\nabla_{y} \odot w) \big) }_{c_1} \odot \hat{x} - \underbrace{ \frac{1}{N} \nabla_{y} \cdot w }_{c_2} \Big) +# +# where :math:`\odot` denotes the element-wise multiplication, :math:`\cdot` denotes the dot product, and :math:`\sigma` is the standard deviation. +# :math:`c_1` and :math:`c_2` are intermediate constants that improve the readability of the following implementation. +# +# For the weights :math:`w` and biases :math:`b`, the VJPs :math:`\nabla_{w}` and :math:`\nabla_{b}` are more straightforward: +# +# .. math:: +# \nabla_{w} = \nabla_{y} \odot \hat{x} \quad \text{and} \quad \nabla_{b} = \nabla_{y} +# +# Since the same weights :math:`w` and biases :math:`b` are used for all rows in the same batch, their gradients need to sum up. +# To perform this step efficiently, we use a parallel reduction strategy: each kernel instance accumulates +# partial :math:`\nabla_{w}` and :math:`\nabla_{b}` across certain rows into one of :math:`\text{GROUP_SIZE_M}` independent buffers. +# These buffers stay in the L2 cache and then are further reduced by another function to compute the actual :math:`\nabla_{w}` and :math:`\nabla_{b}`. +# +# Let the number of input rows :math:`M = 4` and :math:`\text{GROUP_SIZE_M} = 2`, +# here's a diagram of the parallel reduction strategy for :math:`\nabla_{w}` (:math:`\nabla_{b}` is omitted for brevity): +# +# .. image:: parallel_reduction.png +# +# In Stage 1, the rows of X that have the same color share the same buffer and thus a lock is used to ensure that only one kernel instance writes to the buffer at a time. +# In Stage 2, the buffers are further reduced to compute the final :math:`\nabla_{w}` and :math:`\nabla_{b}`. +# In the following implementation, Stage 1 is implemented by the function :code:`_layer_norm_bwd_dx_fused` and Stage 2 is implemented by the function :code:`_layer_norm_bwd_dwdb`. + + +@triton.jit +def _layer_norm_bwd_dx_fused(DX, # pointer to the input gradient + DY, # pointer to the output gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + X, # pointer to the input + W, # pointer to the weights + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + Lock, # pointer to the lock + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + # Map the program id to the elements of X, DX, and DY it should compute. 
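    # For reference, the dx computed below matches the VJP formula stated above; for a
    # single row in plain PyTorch it would read (sketch only, with x, dy, w as 1-D
    # tensors and mean/rstd the saved per-row statistics):
    #   xhat = (x - mean) * rstd
    #   wdy  = w * dy
    #   c1   = (xhat * wdy).sum() / N
    #   c2   = wdy.sum() / N
    #   dx   = (wdy - (xhat * c1 + c2)) * rstd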
+ row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE_N) + mask = cols < N + X += row * stride + DY += row * stride + DX += row * stride + # Offset locks and weights/biases gradient pointer for parallel reduction + lock_id = row % GROUP_SIZE_M + Lock += lock_id + Count = Lock + GROUP_SIZE_M + DW = DW + lock_id * N + cols + DB = DB + lock_id * N + cols + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + w = tl.load(W + cols, mask=mask).to(tl.float32) + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd + wdy = w * dy + xhat = tl.where(mask, xhat, 0.) + wdy = tl.where(mask, wdy, 0.) + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + # Write dx + tl.store(DX + cols, dx, mask=mask) + # Accumulate partial sums for dw/db + partial_dw = (dy * xhat).to(w.dtype) + partial_db = (dy).to(w.dtype) + while tl.atomic_cas(Lock, 0, 1) == 1: + pass + count = tl.load(Count) + # First store doesn't accumulate + if count == 0: + tl.atomic_xchg(Count, 1) + else: + partial_dw += tl.load(DW, mask=mask) + partial_db += tl.load(DB, mask=mask) + tl.store(DW, partial_dw, mask=mask) + tl.store(DB, partial_db, mask=mask) + # Release the lock + tl.atomic_xchg(Lock, 0) + + +@triton.jit +def _layer_norm_bwd_dwdb(DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + FINAL_DW, # pointer to the weights gradient + FINAL_DB, # pointer to the biases gradient + M, # GROUP_SIZE_M + N, # number of columns + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + # Map the program id to the elements of DW and DB it should compute. + pid = tl.program_id(0) + cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # Iterate through the rows of DW and DB to sum the partial sums. + for i in range(0, M, BLOCK_SIZE_M): + rows = i + tl.arange(0, BLOCK_SIZE_M) + mask = (rows[:, None] < M) & (cols[None, :] < N) + offs = rows[:, None] * N + cols[None, :] + dw += tl.load(DW + offs, mask=mask, other=0.) + db += tl.load(DB + offs, mask=mask, other=0.) + # Write the final sum to the output. + sum_dw = tl.sum(dw, axis=0) + sum_db = tl.sum(db, axis=0) + tl.store(FINAL_DW + cols, sum_dw, mask=cols < N) + tl.store(FINAL_DB + cols, sum_db, mask=cols < N) + + +# %% +# Benchmark +# --------- +# +# We can now compare the performance of our kernel against that of PyTorch. +# Here we focus on inputs that have Less than 64KB per feature. +# Specifically, one can set :code:`'mode': 'backward'` to benchmark the backward pass. 
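One detail of the reduction scheme that is easy to miss when reading the backward pass below: a single int32 buffer doubles as both the spin locks and the "written at least once" counters, with the second half addressed as Count = Lock + GROUP_SIZE_M inside `_layer_norm_bwd_dx_fused`. A sketch of the layout (illustrative; GROUP_SIZE_M = 4 is an arbitrary example):

    GROUP_SIZE_M = 4
    locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
    # locks[0:4] -> one spin lock per partial dw/db buffer
    # locks[4:8] -> per-buffer counters: 0 means "first writer stores", nonzero means
    #               later writers accumulate into the existing partial sums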
+ + +class LayerNorm(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, normalized_shape, weight, bias, eps): + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + mean = torch.empty((M, ), dtype=torch.float32, device=x.device) + rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # enqueue kernel + _layer_norm_fwd_fused[(M, )]( # + x_arg, y, weight, bias, mean, rstd, # + x_arg.stride(0), N, eps, # + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) + ctx.save_for_backward(x, weight, bias, mean, rstd) + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.eps = eps + return y + + @staticmethod + def backward(ctx, dy): + x, w, b, m, v = ctx.saved_tensors + # heuristics for amount of parallel reduction stream for DW/DB + N = w.shape[0] + GROUP_SIZE_M = 64 + if N <= 8192: GROUP_SIZE_M = 96 + if N <= 4096: GROUP_SIZE_M = 128 + if N <= 1024: GROUP_SIZE_M = 256 + # allocate output + locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device) + _dw = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device) + _db = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device) + dw = torch.empty((N, ), dtype=w.dtype, device=w.device) + db = torch.empty((N, ), dtype=w.dtype, device=w.device) + dx = torch.empty_like(dy) + # enqueue kernel using forward pass heuristics + # also compute partial sums for DW and DB + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + _layer_norm_bwd_dx_fused[(M, )]( # + dx, dy, _dw, _db, x, w, m, v, locks, # + x_arg.stride(0), N, # + BLOCK_SIZE_N=ctx.BLOCK_SIZE, # + GROUP_SIZE_M=GROUP_SIZE_M, # + num_warps=ctx.num_warps) + grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] + # accumulate partial sums in separate kernel + _layer_norm_bwd_dwdb[grid]( + _dw, _db, dw, db, min(GROUP_SIZE_M, M), N, # + BLOCK_SIZE_M=32, # + BLOCK_SIZE_N=128, num_ctas=1) + return dx, None, dw, db, None + + +layer_norm = LayerNorm.apply + + +def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(True) + # forward pass + y_tri = layer_norm(x, w_shape, weight, bias, eps) + y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + # backward pass (triton) + y_tri.backward(dy, retain_graph=True) + dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]] + x.grad, weight.grad, bias.grad = None, None, None + # backward pass (torch) + y_ref.backward(dy, retain_graph=True) + dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]] + # compare + assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) + assert torch.allclose(dx_tri, dx_ref, atol=1e-2, rtol=0) + assert torch.allclose(db_tri, db_ref, atol=1e-2, rtol=0) + assert torch.allclose(dw_tri, dw_ref, atol=1e-2, rtol=0) + + +@triton.testing.perf_report( + 
triton.testing.Benchmark( + x_names=['N'], + x_vals=[512 * i for i in range(2, 32)], + line_arg='provider', + line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []), + line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []), + styles=[('blue', '-'), ('green', '-'), ('orange', '-')], + ylabel='GB/s', + plot_name='layer-norm-backward', + args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}, + )) +def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(True) + quantiles = [0.5, 0.2, 0.8] + + def y_fwd(): + + if provider == "triton": + return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 + + if provider == "torch": + return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 + + if provider == "apex": + apex_layer_norm = (apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)) + return apex_layer_norm(x) # noqa: F811, E704 + + # forward pass + if mode == 'forward': + gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) + ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) + # backward pass + if mode == 'backward': + y = y_fwd() + gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) # noqa: F811, E704 + ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles, + grad_to_none=[x], rep=500) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + + # test_layer_norm(1151, 8192, torch.float16) + # bench_layer_norm.run(save_path='.', print_data=True) + + +def simple_benchmark(): + M = 4096 + Ns = [512 * i for i in range(2, 10)] + dtype = torch.float16 + providers = ['triton', 'torch'] + (['apex'] if HAS_APEX else []) + eps = 1e-5 + device = 'cuda' + for N in Ns: + for provider in providers: + # Forward pass + for _ in range(10): + x_shape = (M, N) + w_shape = (N,) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(True) + with proton.scope(f"forward [provider={provider}, N={N}]"): + if provider == "triton": + y = layer_norm(x, w_shape, weight, bias, eps) + elif provider == "torch": + y = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) + elif provider == "apex": + apex_layer_norm = (apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)) + y = apex_layer_norm(x) + else: + continue + # Backward pass + with proton.scope(f"backward [provider={provider}, N={N}]"): + if provider in ["triton", "torch", "apex"]: + y.backward(dy, retain_graph=True) + # Clear grads for next iteration + if x.grad is not None: + x.grad.zero_() + if weight.grad is not None: + weight.grad.zero_() + if bias.grad is not None: + bias.grad.zero_() + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + parser.add_argument("--pc_sampling", action="store_true") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with 
torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark"): + simple_benchmark() + # prof.export_chrome_trace("layer_norm_torch.json") + with open("layer_norm_torch.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + backend = "cupti_pcsampling" if args.pc_sampling else "cupti" + proton.start(name="proton_layernorm", context="shadow", backend=backend) + simple_benchmark() + proton.finalize() + else: + print("Profiling with nsys") + simple_benchmark() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/kernel-microbench/05-layer-norm_results.csv b/kernel-microbench/05-layer-norm_results.csv new file mode 100644 index 0000000..f9a562f --- /dev/null +++ b/kernel-microbench/05-layer-norm_results.csv @@ -0,0 +1,6 @@ +run,script,baseline,proton,torch,nsys +1,05-layer-norm.py,8.81,9.05,8.92,9.50 +2,05-layer-norm.py,9.08,9.03,9.07,9.28 +3,05-layer-norm.py,8.84,8.86,9.05,9.46 +4,05-layer-norm.py,8.78,8.71,8.77,9.31 +5,05-layer-norm.py,9.28,9.45,9.16,9.58 diff --git a/kernel-microbench/06-fused-attention.py b/kernel-microbench/06-fused-attention.py new file mode 100644 index 0000000..a831956 --- /dev/null +++ b/kernel-microbench/06-fused-attention.py @@ -0,0 +1,726 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Credits: OpenAI kernel team + +Extra Credits: + +* Original flash attention paper (https://arxiv.org/abs/2205.14135) +* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) + +""" + +# import pytest +import torch +import argparse +import triton.profiler as proton + +import triton +import triton.language as tl + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, # + K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # + N_CTX: tl.constexpr, fp8_v: tl.constexpr): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + # causal = False + else: + lo, hi = 0, N_CTX + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + qk = tl.dot(q, k) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + # -- update output accumulator -- + acc = acc * alpha[:, None] + # update acc + v = tl.load(V_block_ptr) + if fp8_v: + p = p.to(tl.float8e5) + else: + p = p.to(tl.float16) + acc = tl.dot(p, v, acc) + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = 
tl.advance(K_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + + +# We don't run auto-tuning every time to keep the tutorial fast. Keeping +# the code below and commenting out the equivalent parameters is convenient for +# re-tuning. +configs = [ + triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ + for BM in [64, 128]\ + for BN in [32, 64]\ + for s in ([1] if is_hip() else [3, 4, 7])\ + for w in [4, 8]\ +] + + +def keep(conf): + BLOCK_M = conf.kwargs["BLOCK_M"] + BLOCK_N = conf.kwargs["BLOCK_N"] + if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: + return False + return True + + +@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) +@triton.jit +def _attn_fwd(Q, K, V, sm_scale, M, Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + HEAD_DIM: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr # + ): + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + + # block pointers + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=v_order, + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(HEAD_DIM, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compielr to schedule the + # two loops independently + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + # epilogue + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, 
acc.to(Out.type.element_ty)) + + +@triton.jit +def _attn_bwd_preprocess(O, DO, # + Delta, # + Z, H, N_CTX, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # + ): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_hz = tl.program_id(1) + off_n = tl.arange(0, HEAD_DIM) + # load + o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) + do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hz * N_CTX + off_m, delta) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + # Filled in by the wrapper. + start_n, start_m, num_steps, # + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, HEAD_DIM) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + qkT = tl.dot(k, qT) + pT = tl.math.exp2(qkT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _attn_bwd_dq(dq, q, K, V, # + do, m, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, # + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. 
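+ # (k was pre-multiplied by sm_scale / ln(2) in _attention.backward() and m was
+ # stored in base-2 log units by the forward pass, which is why p is recomputed
+ # with exp2 above instead of exp.)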
+ dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok + return dq + + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, # + DO, # + DQ, DK, DV, # + M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # load scales + offs_k = tl.arange(0, HEAD_DIM) + + start_n = pid * BLOCK_N1 + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=True # + ) + + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + # Compute dK and dV for non-masked blocks. + dk, dv = _attn_bwd_dkdv( # + dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=False # + ) + + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) + + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # + MASK=True # + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * BLOCK_N2, num_steps, # + MASK=False # + ) + # Write back dQ. 
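+ # (dq was accumulated against the pre-scaled k, so the dq *= LN2 below removes
+ # the extra 1/ln(2) factor introduced by that pre-scaling.)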
+ dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale): + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + o = torch.empty_like(q) + stage = 3 if causal else 1 + extra_kern_args = {} + # Tuning for AMD target + if is_hip(): + waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 + extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} + + grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + _attn_fwd[grid]( + q, k, v, sm_scale, M, o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + **extra_kern_args) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + NUM_WARPS, NUM_STAGES = 4, 5 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + PRE_BLOCK = 128 + assert N_CTX % PRE_BLOCK == 0 + pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) + delta = torch.empty_like(M) + _attn_bwd_preprocess[pre_grid]( + o, do, # + delta, # + BATCH, N_HEAD, N_CTX, # + BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # + ) + grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # + M, delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + N_HEAD, N_CTX, # + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + num_warps=NUM_WARPS, # + num_stages=NUM_STAGES # + ) + + return dq, dk, dv, None, None + + +attention = _attention.apply + + +# @pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)]) +# @pytest.mark.parametrize("causal", [True]) +# def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): +# torch.manual_seed(20) +# q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) +# k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) +# v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) +# sm_scale = 0.5 +# dout = torch.randn_like(q) +# # reference implementation +# M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) +# p = torch.matmul(q, k.transpose(2, 3)) * sm_scale +# if causal: +# p[:, :, M == 0] = float("-inf") +# p = 
torch.softmax(p.float(), dim=-1).half() +# # p = torch.exp(p) +# ref_out = torch.matmul(p, v) +# ref_out.backward(dout) +# ref_dv, v.grad = v.grad.clone(), None +# ref_dk, k.grad = k.grad.clone(), None +# ref_dq, q.grad = q.grad.clone(), None +# # triton implementation +# tri_out = attention(q, k, v, causal, sm_scale).half() +# tri_out.backward(dout) +# tri_dv, v.grad = v.grad.clone(), None +# tri_dk, k.grad = k.grad.clone(), None +# tri_dq, q.grad = q.grad.clone(), None +# # compare +# assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) +# rtol = 0.0 +# # Relative tolerance workaround for known hardware limitation of MI200 GPU. +# # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices +# if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": +# rtol = 1e-2 +# assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol) +# assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol) +# assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol) + + +# try: +# from flash_attn.flash_attn_interface import \ +# flash_attn_qkvpacked_func as flash_attn_func +# HAS_FLASH = True +# except BaseException: +# HAS_FLASH = False + +# TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') +# BATCH, N_HEADS, HEAD_DIM = 4, 32, 64 +# # vary seq length for fixed head and batch=4 +# configs = [] +# for mode in ["fwd", "bwd"]: +# for causal in [True, False]: +# if mode == "bwd" and not causal: +# continue +# configs.append( +# triton.testing.Benchmark( +# x_names=["N_CTX"], +# x_vals=[2**i for i in range(10, 15)], +# line_arg="provider", +# line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) + +# (["flash"] if HAS_FLASH else []), +# line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + +# (["Flash-2"] if HAS_FLASH else []), +# styles=[("red", "-"), ("blue", "-"), ("green", "-")], +# ylabel="TFLOPS", +# plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}", +# args={ +# "H": N_HEADS, +# "BATCH": BATCH, +# "HEAD_DIM": HEAD_DIM, +# "mode": mode, +# "causal": causal, +# }, +# )) + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"): + assert mode in ["fwd", "bwd"] + dtype = torch.float16 + if "triton" in provider: + q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + if mode == "fwd" and "fp8" in provider: + q = q.to(torch.float8_e5m2) + k = k.to(torch.float8_e5m2) + v = v.permute(0, 1, 3, 2).contiguous() + v = v.permute(0, 1, 3, 2) + v = v.to(torch.float8_e5m2) + sm_scale = 1.3 + fn = lambda: attention(q, k, v, causal, sm_scale) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn) + if provider == "flash": + qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + fn = lambda: flash_attn_func(qkv, causal=causal) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn) + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM + total_flops = 2 * 
flops_per_matmul + if causal: + total_flops *= 0.5 + if mode == "bwd": + total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) + return total_flops * 1e-12 / (ms * 1e-3) + + +def simple_benchmark(): + # Use the same configs as in the perf_report for argument ranges + TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') + try: + from flash_attn.flash_attn_interface import \ + flash_attn_qkvpacked_func as flash_attn_func + HAS_FLASH = True + except BaseException: + HAS_FLASH = False + BATCH, N_HEADS, HEAD_DIM = 4, 32, 64 + nctx_vals = [2**i for i in range(10, 15)] + for mode in ["fwd", "bwd"]: + for causal in [True, False]: + if mode == "bwd" and not causal: + continue + providers = ["triton-fp16"] + if TORCH_HAS_FP8: + providers.append("triton-fp8") + if HAS_FLASH: + providers.append("flash") + for provider in providers: + for N_CTX in nctx_vals: + H = N_HEADS + BATCH_ = BATCH + HEAD_DIM_ = HEAD_DIM + for _ in range(20): # Run a few iterations for each config + dtype = torch.float16 + if "triton" in provider: + q = torch.randn((BATCH_, H, N_CTX, HEAD_DIM_), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH_, H, N_CTX, HEAD_DIM_), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH_, H, N_CTX, HEAD_DIM_), dtype=dtype, device="cuda", requires_grad=True) + if mode == "fwd" and "fp8" in provider: + q = q.to(torch.float8_e5m2) + k = k.to(torch.float8_e5m2) + v = v.permute(0, 1, 3, 2).contiguous() + v = v.permute(0, 1, 3, 2) + v = v.to(torch.float8_e5m2) + sm_scale = 1.3 + if mode == "fwd": + with proton.scope("fused_attention_forward"): + attention(q, k, v, causal, sm_scale) + else: + with proton.scope("fused_attention_forward"): + o = attention(q, k, v, causal, sm_scale) + do = torch.randn_like(o) + with proton.scope("fused_attention_backward"): + o.backward(do, retain_graph=True) + elif provider == "flash": + qkv = torch.randn((BATCH_, N_CTX, 3, H, HEAD_DIM_), dtype=dtype, device="cuda", requires_grad=True) + if mode == "fwd": + with proton.scope("fused_attention_forward"): + flash_attn_func(qkv, causal=causal) + else: + with proton.scope("fused_attention_forward"): + o = flash_attn_func(qkv, causal=causal) + do = torch.randn_like(o) + with proton.scope("fused_attention_backward"): + o.backward(do, retain_graph=True) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="", help="Profiler to use: torch, proton, nsys") + parser.add_argument("--pc_sampling", action="store_true", help="Use PC sampling with proton") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark"): + simple_benchmark() + # prof.export_chrome_trace("fused_attention_torch.json") + with open("fused_attention_torch.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + backend = "cupti_pcsampling" if args.pc_sampling else "cupti" + proton.start(name="proton_fused_attention", context="shadow", backend=backend) + simple_benchmark() + proton.finalize() + else: + print("Profiling with nsys") + simple_benchmark() + + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/06-fused-attention_results.csv b/kernel-microbench/06-fused-attention_results.csv new file mode 100644 index 0000000..b47cb53 --- /dev/null +++ 
b/kernel-microbench/06-fused-attention_results.csv @@ -0,0 +1,6 @@ +run,script,baseline,proton,torch,nsys +1,06-fused-attention.py,26.28,27.14,34.70,27.76 +2,06-fused-attention.py,26.26,27.06,35.01,27.96 +3,06-fused-attention.py,25.95,26.79,34.09,28.22 +4,06-fused-attention.py,26.46,26.86,33.81,28.96 +5,06-fused-attention.py,26.48,26.99,34.12,27.35 diff --git a/kernel-microbench/07-grouped-gemm.py b/kernel-microbench/07-grouped-gemm.py new file mode 100644 index 0000000..2d0bbbe --- /dev/null +++ b/kernel-microbench/07-grouped-gemm.py @@ -0,0 +1,352 @@ +""" +Group GEMM +============================ +This group gemm kernel launches a fixed number of CTA to compute a group +of gemms. The scheduling is static and we do it on device. +""" + +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import torch +import argparse +import triton.profiler as proton + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 84, + }), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 128, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 84, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 128, + }), + ], + key=['group_size'], +) +@triton.jit +def grouped_matmul_kernel( + # device tensor of matrices pointers + group_a_ptrs, + group_b_ptrs, + group_c_ptrs, + # device tensor of gemm sizes. its shape is [group_size, 3] + # dim 0 is group_size, dim 1 is the values of of each gemm + group_gemm_sizes, + # device tensor of leading dimension sizes. 
its shape is [group_size, 3] + # dim 0 is group_size, dim 1 is the values of of each gemm + g_lds, + # number of gemms + group_size, + # number of virtual SM + NUM_SM: tl.constexpr, + # tile sizes + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + tile_idx = tl.program_id(0) + last_problem_end = 0 + for g in range(group_size): + # get the gemm size of the current problem + gm = tl.load(group_gemm_sizes + g * 3) + gn = tl.load(group_gemm_sizes + g * 3 + 1) + gk = tl.load(group_gemm_sizes + g * 3 + 2) + num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) + num_tiles = num_m_tiles * num_n_tiles + # iterate through the tiles in the current gemm problem + while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles): + # pick up a tile from the current gemm problem + k = gk + lda = tl.load(g_lds + g * 3) + ldb = tl.load(g_lds + g * 3 + 1) + ldc = tl.load(g_lds + g * 3 + 2) + a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16)) + b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16)) + c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16)) + # figure out tile coordinates + tile_idx_in_gemm = tile_idx - last_problem_end + tile_m_idx = tile_idx_in_gemm // num_n_tiles + tile_n_idx = tile_idx_in_gemm % num_n_tiles + + # do regular gemm here + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :] + b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + # hint to Triton compiler to do proper loop pipelining + tl.multiple_of(a_ptrs, [16, 16]) + tl.multiple_of(b_ptrs, [16, 16]) + # assume full tile for now + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * ldb + c = accumulator.to(tl.float16) + + offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :] + + # assumes full tile for now + tl.store(c_ptrs, c) + + # go to the next tile by advancing NUM_SM + tile_idx += NUM_SM + + # get ready to go to the next gemm problem + last_problem_end = last_problem_end + num_tiles + + +def group_gemm_fn(group_A, group_B): + device = torch.device('cuda') + assert len(group_A) == len(group_B) + group_size = len(group_A) + + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for i in range(group_size): + A = group_A[i] + B = group_B[i] + assert A.shape[1] == B.shape[0] + M, K = A.shape + K, N = B.shape + C = torch.empty((M, N), device=device, dtype=A.dtype) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [M, N, K] + g_lds += [A.stride(0), B.stride(0), C.stride(0)] + + # note these are device tensors + d_a_ptrs = torch.tensor(A_addrs, device=device) + d_b_ptrs = torch.tensor(B_addrs, device=device) + d_c_ptrs = torch.tensor(C_addrs, device=device) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device) + # we use a fixed number of CTA, and it's auto-tunable + grid = lambda META: (META['NUM_SM'], ) + 
grouped_matmul_kernel[grid]( + d_a_ptrs, + d_b_ptrs, + d_c_ptrs, + d_g_sizes, + d_g_lds, + group_size, + ) + + return group_C + + +group_m = [1024, 512, 256, 128] +group_n = [1024, 512, 256, 128] +group_k = [1024, 512, 256, 128] +group_A = [] +group_B = [] +assert len(group_m) == len(group_n) +assert len(group_n) == len(group_k) +group_size = len(group_m) +for i in range(group_size): + M = group_m[i] + N = group_n[i] + K = group_k[i] + A = torch.rand((M, K), device="cuda", dtype=torch.float16) + B = torch.rand((K, N), device="cuda", dtype=torch.float16) + group_A.append(A) + group_B.append(B) + +tri_out = group_gemm_fn(group_A, group_B) +ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)] +for i in range(group_size): + assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=0) + + +# only launch the kernel, no tensor preparation here to remove all overhead +def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): + grid = lambda META: (META['NUM_SM'], ) + grouped_matmul_kernel[grid]( + a_ptrs, + b_ptrs, + c_ptrs, + sizes, + lds, + group_size, + ) + + +def torch_perf_fn(group_A, group_B): + for a, b in zip(group_A, group_B): + torch.matmul(a, b) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + # argument names to use as an x-axis for the plot + x_names=['N'], + x_vals=[2**i for i in range(7, 11)], # different possible values for `x_name` + line_arg='provider', + # argument name whose value corresponds to a different line in the plot + # possible values for `line_arg`` + line_vals=['cublas', 'triton'], + # label name for the lines + line_names=["cuBLAS", "Triton"], + # line styles + styles=[('green', '-'), ('blue', '-')], + ylabel="runtime(ms)", # label name for the y-axis + plot_name="group-gemm-performance", + # name for the plot. Used also as a file name for saving the plot. 
+ args={}, + )) +def benchmark(N, provider): + group_size = 4 + group_A = [] + group_B = [] + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for i in range(group_size): + A = torch.rand((N, N), device="cuda", dtype=torch.float16) + B = torch.rand((N, N), device="cuda", dtype=torch.float16) + C = torch.empty((N, N), device="cuda", dtype=torch.float16) + group_A.append(A) + group_B.append(B) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [N, N, N] + g_lds += [N, N, N] + + d_a_ptrs = torch.tensor(A_addrs, device="cuda") + d_b_ptrs = torch.tensor(B_addrs, device="cuda") + d_c_ptrs = torch.tensor(C_addrs, device="cuda") + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="cuda") + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="cuda") + + quantiles = [0.5, 0.2, 0.8] + if provider == 'cublas': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles) + return ms, max_ms, min_ms + + +def simple_benchmark(): + # Use the same configs as in the perf_report for argument ranges + sizes = [2 ** i for i in range(7, 14)] # N values from perf_report + providers = ["cublas", "triton"] + group_size = 4 + for provider in providers: + for N in sizes: + for _ in range(500): # Run a few iterations for each config + group_A = [] + group_B = [] + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for i in range(group_size): + A = torch.rand((N, N), device="cuda", dtype=torch.float16) + B = torch.rand((N, N), device="cuda", dtype=torch.float16) + C = torch.empty((N, N), device="cuda", dtype=torch.float16) + group_A.append(A) + group_B.append(B) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [N, N, N] + g_lds += [N, N, N] + + d_a_ptrs = torch.tensor(A_addrs, device="cuda") + d_b_ptrs = torch.tensor(B_addrs, device="cuda") + d_c_ptrs = torch.tensor(C_addrs, device="cuda") + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="cuda") + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="cuda") + + if provider == 'cublas': + torch_perf_fn(group_A, group_B) + elif provider == 'triton': + triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + parser.add_argument("--pc_sampling", action="store_true") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark"): + simple_benchmark() + with open("grouped_gemm_torch.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + backend = "cupti_pcsampling" if args.pc_sampling else "cupti" + proton.start(name="proton_grouped_gemm", context="shadow", backend=backend) + simple_benchmark() + proton.finalize() + else: + print("Profiling with nsys") + simple_benchmark() + + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/08-grouped-gemm_results.csv 
b/kernel-microbench/08-grouped-gemm_results.csv new file mode 100644 index 0000000..0939260 --- /dev/null +++ b/kernel-microbench/08-grouped-gemm_results.csv @@ -0,0 +1,6 @@ +run,script,baseline,proton,torch,nsys +1,08-grouped-gemm.py,8.82,9.82,9.08,9.17 +2,08-grouped-gemm.py,8.71,8.64,8.75,9.19 +3,08-grouped-gemm.py,8.78,9.10,8.79,9.20 +4,08-grouped-gemm.py,9.85,8.54,8.59,9.01 +5,08-grouped-gemm.py,8.56,8.86,8.70,9.21 diff --git a/kernel-microbench/12-liger-rms-norm.py b/kernel-microbench/12-liger-rms-norm.py new file mode 100644 index 0000000..845f071 --- /dev/null +++ b/kernel-microbench/12-liger-rms-norm.py @@ -0,0 +1,50 @@ +import argparse +import triton.profiler as proton + +from typing import Optional + +import torch + +from liger_kernel.ops.rms_norm import rms_norm_forward +""" +the header defined in liger_kernel +def rms_norm_forward(X, W, eps, offset, casting_mode): +""" + +def simple_benchmark_rms_norm(): + sizes = [2 ** i for i in range(10, 16)] + device = "cuda" + for size in sizes: + batch = size + num_features = size // 2 + for _ in range(500): + X = torch.rand((batch, num_features), device=device, dtype=torch.float32) + W = torch.rand((num_features,), device=device, dtype=torch.float32) + eps = 1e-5 + offset = 0 # default offset + casting_mode = 0 # default casting mode, update if needed + rms_norm_forward(X, W, eps, offset, casting_mode) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_rms_norm"): + simple_benchmark_rms_norm() + with open("rms_norm_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + proton.start(name="proton_rms_norm", context="shadow", backend="cupti") + simple_benchmark_rms_norm() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_rms_norm() + +if __name__ == "__main__": + main() + diff --git a/kernel-microbench/13-liger-tvd.py b/kernel-microbench/13-liger-tvd.py new file mode 100644 index 0000000..f4c8fbf --- /dev/null +++ b/kernel-microbench/13-liger-tvd.py @@ -0,0 +1,58 @@ +import argparse +import triton.profiler as proton + +from typing import Optional + +import torch + +from liger_kernel.ops.tvd import tv_distance_forward_triton + +""" +the header defined in liger_kernel +def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label) +""" + +def simple_benchmark_tvd(): + sizes = [2 ** i for i in range(10, 16)] + device = "cuda" + for size in sizes: + batch = size + num_classes = size // 2 + for _ in range(500): + p = torch.rand((batch, num_classes), device=device, dtype=torch.float32) + q = torch.rand((batch, num_classes), device=device, dtype=torch.float32) + shift_labels = False + reduction = 1 # 0: none, 1: mean, 2: sum (typical torch convention) + ignore_index = -100 + has_label = False + tv_distance_forward_triton( + p, + q, + shift_labels, + reduction, + ignore_index, + has_label, + ) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with 
torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_tvd"): + simple_benchmark_tvd() + with open("tvd_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + proton.start(name="proton_tvd", context="shadow", backend="cupti") + simple_benchmark_tvd() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_tvd() + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/14-liger-fused-linear-cross-entropy-bwd.py b/kernel-microbench/14-liger-fused-linear-cross-entropy-bwd.py new file mode 100644 index 0000000..d0e314a --- /dev/null +++ b/kernel-microbench/14-liger-fused-linear-cross-entropy-bwd.py @@ -0,0 +1,77 @@ +import argparse +import triton.profiler as proton + +from typing import Optional + +import torch + +from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward + +""" +the header defined in liger_kernel +def cross_entropy_forward( + _input, + target, + weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, +): +""" + +def simple_benchmark_cross_entropy(): + # Example parameter sets for cross_entropy_forward + sizes = [2 ** i for i in range(8, 14)] + device = "cuda" + for size in sizes: + batch = size + num_classes = size // 2 + for _ in range(500): + _input = torch.randn((batch, num_classes), device=device, dtype=torch.float32, requires_grad=True) + target = torch.randint(0, num_classes, (batch,), device=device, dtype=torch.int64) + weight = torch.rand(num_classes, device=device, dtype=torch.float32) + ignore_index = -100 + lse_square_scale = 0.0 + label_smoothing = 0.0 + reduction = 1 # 0: none, 1: mean, 2: sum (typical torch convention) + softcap = 0.0 + return_z_loss = False + # Call the kernel + fused_linear_cross_entropy_forward( + _input, + target, + weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, + ) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_cross_entropy"): + simple_benchmark_cross_entropy() + with open("cross_entropy_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + proton.start(name="proton_cross_entropy", context="shadow", backend="cupti") + simple_benchmark_cross_entropy() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_cross_entropy() + +if __name__ == "__main__": + main() + diff --git a/kernel-microbench/15-liger-jsd.py b/kernel-microbench/15-liger-jsd.py new file mode 100644 index 0000000..b9d1246 --- /dev/null +++ b/kernel-microbench/15-liger-jsd.py @@ -0,0 +1,58 @@ +import argparse +import triton.profiler as proton + +from typing import Optional + +import torch + +from liger_kernel.ops.jsd import jsd_forward + +""" +the header defined in liger_kernel +def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label): +""" + +def simple_benchmark_jsd(): + sizes = [2 ** i for i in range(10, 16)] + 
device = "cuda" + for size in sizes: + batch = size + num_classes = size // 2 + for _ in range(500): + _input = torch.rand((batch, num_classes), device=device, dtype=torch.float32) + target = torch.rand((batch, num_classes), device=device, dtype=torch.float32) + shift_labels = False + beta = 1.0 + ignore_index = -100 + has_label = False + jsd_forward( + _input, + target, + shift_labels, + beta, + ignore_index, + has_label, + ) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_jsd"): + simple_benchmark_jsd() + with open("jsd_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + proton.start(name="proton_jsd", context="shadow", backend="cupti") + simple_benchmark_jsd() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_jsd() + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/16-liger-grpo-loss-fwd.py b/kernel-microbench/16-liger-grpo-loss-fwd.py new file mode 100644 index 0000000..4b0dbe0 --- /dev/null +++ b/kernel-microbench/16-liger-grpo-loss-fwd.py @@ -0,0 +1,86 @@ +import argparse +import triton.profiler as proton + +from typing import Optional + +import torch + +from liger_kernel.ops.grpo_loss import GrpoLossFunction + +""" +the header defined in liger_kernel +class GrpoLossFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + inplace, + ): +""" + +def simple_benchmark_grpo_loss(): + sizes = [2 ** i for i in range(8, 14)] + device = "cuda" + vocab_size = 32000 + for size in sizes: + batch = size + seq_len = size // 2 + for _ in range(5): + logits = torch.randn((batch, seq_len, vocab_size), device=device, dtype=torch.float32, requires_grad=True) + old_logp = torch.randn((batch, seq_len), device=device, dtype=torch.float32) + ref_logp = torch.randn((batch, seq_len), device=device, dtype=torch.float32) + completion_ids = torch.randint(0, vocab_size, (batch, seq_len), device=device, dtype=torch.int64) + advantages = torch.randn((batch, seq_len), device=device, dtype=torch.float32) + completion_mask = torch.randint(0, 2, (batch, seq_len), device=device, dtype=torch.int32) + temperature = 1.0 + beta = 1.0 + eps_low = 1e-6 + eps_high = 1e-2 + inplace = False + # Call the kernel + GrpoLossFunction.apply( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + inplace, + ) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_grpo_loss"): + simple_benchmark_grpo_loss() + with open("grpo_loss_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + proton.start(name="proton_grpo_loss", 
context="shadow", backend="cupti") + simple_benchmark_grpo_loss() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_grpo_loss() + +if __name__ == "__main__": + main() + diff --git a/kernel-microbench/17-liger-kl-div.py b/kernel-microbench/17-liger-kl-div.py new file mode 100644 index 0000000..7656baf --- /dev/null +++ b/kernel-microbench/17-liger-kl-div.py @@ -0,0 +1,53 @@ +from typing import Literal + +import torch +import triton +import triton.language as tl +import triton.profiler as proton + +from liger_kernel.ops.kl_div import kldiv_forward_triton + +""" +the header defined in liger_kernel +def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V] + +""" + +def simple_benchmark_kldiv(): + sizes = [2 ** i for i in range(8, 14)] + device = "cuda" + for size in sizes: + batch = size + vocab_size = size // 2 + for _ in range(500): + y_pred = torch.randn((batch, vocab_size), device=device, dtype=torch.float32) + y_true = torch.randn((batch, vocab_size), device=device, dtype=torch.float32) + log_target = False + reduction = 1 # 0: none, 1: mean, 2: sum + eps = 1e-6 + kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_kldiv"): + simple_benchmark_kldiv() + with open("kldiv_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + proton.start(name="proton_kldiv", context="shadow", backend="cupti") + simple_benchmark_kldiv() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_kldiv() + +if __name__ == "__main__": + main() + diff --git a/kernel-microbench/18-liger-kl-div-bwd.py b/kernel-microbench/18-liger-kl-div-bwd.py new file mode 100644 index 0000000..c2cb929 --- /dev/null +++ b/kernel-microbench/18-liger-kl-div-bwd.py @@ -0,0 +1,52 @@ +from typing import Literal + +import torch +import triton +import triton.language as tl +import triton.profiler as proton + +from liger_kernel.ops.kl_div import kldiv_backward_triton + +""" +the header defined in liger_kernel +def kldiv_backward_triton(target, grad_output, new_grads, log_target): + +""" + +def simple_benchmark_kldiv_bwd(): + sizes = [2 ** i for i in range(8, 14)] + device = "cuda" + for size in sizes: + batch = size + vocab_size = size // 2 + for _ in range(500): + target = torch.randn((batch, vocab_size), device=device, dtype=torch.float32) + grad_output = torch.randn((batch, vocab_size), device=device, dtype=torch.float32) + new_grads = torch.empty((batch, vocab_size), device=device, dtype=torch.float32) + log_target = False + kldiv_backward_triton(target, grad_output, new_grads, log_target) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_kldiv_bwd"): + simple_benchmark_kldiv_bwd() + with open("kldiv_bwd_trace.json", "w") as f: + 
f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + import triton.profiler as proton + proton.start(name="proton_kldiv_bwd", context="shadow", backend="cupti") + simple_benchmark_kldiv_bwd() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_kldiv_bwd() + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/19-liger-geglu.py b/kernel-microbench/19-liger-geglu.py new file mode 100644 index 0000000..a90d7a5 --- /dev/null +++ b/kernel-microbench/19-liger-geglu.py @@ -0,0 +1,52 @@ +import operator + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import ensure_contiguous + +from liger_kernel.ops.geglu import geglu_forward + +""" +the header defined in liger_kernel +def geglu_forward(a, b): +""" + +def simple_benchmark_geglu_fwd(): + sizes = [2 ** i for i in range(8, 14)] + device = "cuda" + for size in sizes: + batch = size + hidden = size // 2 + for _ in range(500): + a = torch.randn((batch, hidden), device=device, dtype=torch.float32) + b = torch.randn((batch, hidden), device=device, dtype=torch.float32) + geglu_forward(a, b) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_geglu_fwd"): + simple_benchmark_geglu_fwd() + with open("geglu_fwd_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + import triton.profiler as proton + proton.start(name="proton_geglu_fwd", context="shadow", backend="cupti") + simple_benchmark_geglu_fwd() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_geglu_fwd() + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/20-liger-geglu-bwd.py b/kernel-microbench/20-liger-geglu-bwd.py new file mode 100644 index 0000000..ff1a40f --- /dev/null +++ b/kernel-microbench/20-liger-geglu-bwd.py @@ -0,0 +1,49 @@ +import operator + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.geglu import geglu_backward + +""" +the header defined in liger_kernel +def geglu_backward(a, b, dc): +""" + +def simple_benchmark_geglu_bwd(): + sizes = [2 ** i for i in range(8, 14)] + device = "cuda" + for size in sizes: + batch = size + hidden = size // 2 + for _ in range(500): + a = torch.randn((batch, hidden), device=device, dtype=torch.float32) + b = torch.randn((batch, hidden), device=device, dtype=torch.float32) + dc = torch.randn((batch, hidden), device=device, dtype=torch.float32) + geglu_backward(a, b, dc) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_geglu_bwd"): + simple_benchmark_geglu_bwd() + with open("geglu_bwd_trace.json", "w") as f: + 
f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + import triton.profiler as proton + proton.start(name="proton_geglu_bwd", context="shadow", backend="cupti") + simple_benchmark_geglu_bwd() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_geglu_bwd() + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/21-liger-group-norm-bwd.py b/kernel-microbench/21-liger-group-norm-bwd.py new file mode 100644 index 0000000..5f6d7fe --- /dev/null +++ b/kernel-microbench/21-liger-group-norm-bwd.py @@ -0,0 +1,55 @@ +import operator + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.group_norm import group_norm_backward + +""" +the header defined in liger_kernel +def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): +""" + +def simple_benchmark_group_norm_bwd(): + sizes = [2 ** i for i in range(8, 14)] + device = "cuda" + for size in sizes: + batch = size + num_channels = size // 2 + num_groups = max(1, num_channels // 8) + seq_len = 32 + shape = (batch, num_channels, seq_len) + for _ in range(500): + dY = torch.randn(shape, device=device, dtype=torch.float32) + X = torch.randn(shape, device=device, dtype=torch.float32) + W = torch.randn(num_channels, device=device, dtype=torch.float32) + B = torch.randn(num_channels, device=device, dtype=torch.float32) + Mean = torch.randn((batch, num_groups), device=device, dtype=torch.float32) + RSTD = torch.randn((batch, num_groups), device=device, dtype=torch.float32) + group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_group_norm_bwd"): + simple_benchmark_group_norm_bwd() + with open("group_norm_bwd_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + import triton.profiler as proton + proton.start(name="proton_group_norm_bwd", context="shadow", backend="cupti") + simple_benchmark_group_norm_bwd() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_group_norm_bwd() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/kernel-microbench/22-liger-group-norm-fwd.py b/kernel-microbench/22-liger-group-norm-fwd.py new file mode 100644 index 0000000..379bf89 --- /dev/null +++ b/kernel-microbench/22-liger-group-norm-fwd.py @@ -0,0 +1,55 @@ +import operator + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.group_norm import group_norm_forward + +""" +the header defined in liger_kernel +def group_norm_forward(X, num_channels, num_groups, W, B, eps): + +""" + +def simple_benchmark_group_norm_fwd(): + sizes = [2 ** i for i in range(8, 14)] + device = "cuda" + for size in sizes: + batch = size + num_channels = size // 2 + num_groups = max(1, num_channels // 8) + seq_len = 32 + shape = (batch, num_channels, seq_len) + for _ in range(500): + X = torch.randn(shape, device=device, dtype=torch.float32) + W = torch.randn(num_channels, device=device, 
dtype=torch.float32) + B = torch.randn(num_channels, device=device, dtype=torch.float32) + eps = 1e-5 + group_norm_forward(X, num_channels, num_groups, W, B, eps) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_group_norm_fwd"): + simple_benchmark_group_norm_fwd() + with open("group_norm_fwd_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + import triton.profiler as proton + proton.start(name="proton_group_norm_fwd", context="shadow", backend="cupti") + simple_benchmark_group_norm_fwd() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_group_norm_fwd() + +if __name__ == "__main__": + main() + diff --git a/kernel-microbench/23-qwen2-mrope.py b/kernel-microbench/23-qwen2-mrope.py new file mode 100644 index 0000000..8933acc --- /dev/null +++ b/kernel-microbench/23-qwen2-mrope.py @@ -0,0 +1,53 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.qwen2vl_mrope import qwen2vl_mrope_forward + +""" +the header defined in liger_kernel +def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section): + +""" + +def simple_benchmark_qwen2_mrope(): + sizes = [2 ** i for i in range(7, 11)] + device = "cuda" + for size in sizes: + batch = size + seq_len = size + n_q_head = 16 + n_kv_head = 8 + head_dim = 128 + q = torch.randn((batch, seq_len, n_q_head, head_dim), device=device, dtype=torch.float32) + k = torch.randn((batch, seq_len, n_kv_head, head_dim), device=device, dtype=torch.float32) + cos = torch.randn((seq_len, head_dim), device=device, dtype=torch.float32) + sin = torch.randn((seq_len, head_dim), device=device, dtype=torch.float32) + mrope_section = (0, seq_len) + for _ in range(500): + qwen2vl_mrope_forward(q, k, cos, sin, mrope_section) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_qwen2_mrope"): + simple_benchmark_qwen2_mrope() + with open("qwen2_mrope_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + import triton.profiler as proton + proton.start(name="proton_qwen2_mrope", context="shadow", backend="cupti") + simple_benchmark_qwen2_mrope() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_qwen2_mrope() + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/24-liger-sparsemax.py b/kernel-microbench/24-liger-sparsemax.py new file mode 100644 index 0000000..bde308c --- /dev/null +++ b/kernel-microbench/24-liger-sparsemax.py @@ -0,0 +1,48 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.sparsemax import LigerSparsemaxFunction + +""" +the header defined in liger_kernel +class LigerSparsemaxFunction(torch.autograd.Function): + @staticmethod + 
@ensure_contiguous + def forward(ctx, x: torch.Tensor, dim: int): +""" + +def simple_benchmark_sparsemax_fwd(): + sizes = [2 ** i for i in range(8, 14)] + device = "cuda" + for size in sizes: + batch = size + dim = size // 2 + for _ in range(500): + x = torch.randn((batch, dim), device=device, dtype=torch.float32) + LigerSparsemaxFunction.apply(x, 1) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_sparsemax_fwd"): + simple_benchmark_sparsemax_fwd() + with open("sparsemax_fwd_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + import triton.profiler as proton + proton.start(name="proton_sparsemax_fwd", context="shadow", backend="cupti") + simple_benchmark_sparsemax_fwd() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_sparsemax_fwd() + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/25-liger-dyt.py b/kernel-microbench/25-liger-dyt.py new file mode 100644 index 0000000..fb65d70 --- /dev/null +++ b/kernel-microbench/25-liger-dyt.py @@ -0,0 +1,49 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.dyt import liger_dyt_fwd + +""" +the header defined in liger_kernel +def liger_dyt_fwd(x, alpha, gamma, beta): + +""" + +def simple_benchmark_liger_dyt_fwd(): + sizes = [2 ** i for i in range(8, 14)] + device = "cuda" + for size in sizes: + batch = size + dim = size // 2 + for _ in range(500): + x = torch.randn((batch, dim), device=device, dtype=torch.float32) + alpha = torch.randn(dim, device=device, dtype=torch.float32) + gamma = torch.randn(dim, device=device, dtype=torch.float32) + beta = torch.randn(dim, device=device, dtype=torch.float32) + liger_dyt_fwd(x, alpha, gamma, beta) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_liger_dyt_fwd"): + simple_benchmark_liger_dyt_fwd() + with open("liger_dyt_fwd_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + import triton.profiler as proton + proton.start(name="proton_liger_dyt_fwd", context="shadow", backend="cupti") + simple_benchmark_liger_dyt_fwd() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_liger_dyt_fwd() + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/26-liger-dyt-bwd.py b/kernel-microbench/26-liger-dyt-bwd.py new file mode 100644 index 0000000..9a221bb --- /dev/null +++ b/kernel-microbench/26-liger-dyt-bwd.py @@ -0,0 +1,51 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.dyt import liger_dyt_bwd + +""" +the header defined in liger_kernel +def liger_dyt_bwd(dy, x, alpha, gamma, beta): + +""" + +def simple_benchmark_liger_dyt_bwd(): + sizes = [2 ** i for i in 
range(8, 14)] + device = "cuda" + for size in sizes: + batch = size + dim = size // 2 + x = torch.randn((batch, dim), device=device, dtype=torch.float32) + alpha = torch.randn(dim, device=device, dtype=torch.float32) + gamma = torch.randn(dim, device=device, dtype=torch.float32) + beta = torch.randn(dim, device=device, dtype=torch.float32) + dy = torch.randn((batch, dim), device=device, dtype=torch.float32) + + for _ in range(500): + liger_dyt_bwd(dy, x, alpha, gamma, beta) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_liger_dyt_bwd"): + simple_benchmark_liger_dyt_bwd() + with open("liger_dyt_bwd_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + import triton.profiler as proton + proton.start(name="proton_liger_dyt_bwd", context="shadow", backend="cupti") + simple_benchmark_liger_dyt_bwd() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_liger_dyt_bwd() + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/README.rst b/kernel-microbench/README.rst new file mode 100644 index 0000000..1dfa5f4 --- /dev/null +++ b/kernel-microbench/README.rst @@ -0,0 +1,11 @@ +Tutorials +========= + +Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one. + +To install the dependencies for the tutorials: + +.. 
code-block:: bash + + cd triton + pip install -e './python[tutorials]' diff --git a/kernel-microbench/all_results.csv b/kernel-microbench/all_results.csv new file mode 100644 index 0000000..aec07dc --- /dev/null +++ b/kernel-microbench/all_results.csv @@ -0,0 +1,36 @@ +run,script,baseline,proton,torch,nsys +1,01-vector-add.py,7.68,8.15,7.86,8.78 +2,01-vector-add.py,7.44,8.01,8.24,8.35 +3,01-vector-add.py,7.32,7.75,7.83,8.33 +4,01-vector-add.py,7.60,7.86,8.05,8.36 +5,01-vector-add.py,7.53,8.18,7.59,8.14 +1,02-fused-softmax.py,9.39,9.69,9.78,10.17 +2,02-fused-softmax.py,9.26,9.64,10.15,10.40 +3,02-fused-softmax.py,9.81,10.13,10.05,10.42 +4,02-fused-softmax.py,9.32,9.66,10.07,10.24 +5,02-fused-softmax.py,9.26,9.67,9.88,10.44 +1,03-matrix-multiplication.py,12.99,12.52,13.24,14.19 +2,03-matrix-multiplication.py,12.79,12.27,13.66,14.76 +3,03-matrix-multiplication.py,12.52,12.36,13.89,14.50 +4,03-matrix-multiplication.py,12.82,12.90,13.54,14.57 +5,03-matrix-multiplication.py,12.95,12.24,13.40,15.12 +1,05-layer-norm.py,8.81,9.05,8.92,9.50 +2,05-layer-norm.py,9.08,9.03,9.07,9.28 +3,05-layer-norm.py,8.84,8.86,9.05,9.46 +4,05-layer-norm.py,8.78,8.71,8.77,9.31 +5,05-layer-norm.py,9.28,9.45,9.16,9.58 +1,04-low-memory-dropout.py,8.21,8.93,7.91,8.54 +2,04-low-memory-dropout.py,8.50,8.09,8.54,8.98 +3,04-low-memory-dropout.py,7.61,8.45,8.26,8.52 +4,04-low-memory-dropout.py,8.07,8.44,8.22,8.32 +5,04-low-memory-dropout.py,7.71,8.23,8.10,8.56 +1,06-fused-attention.py,26.28,27.14,34.70,27.76 +2,06-fused-attention.py,26.26,27.06,35.01,27.96 +3,06-fused-attention.py,25.95,26.79,34.09,28.22 +4,06-fused-attention.py,26.46,26.86,33.81,28.96 +5,06-fused-attention.py,26.48,26.99,34.12,27.35 +1,08-grouped-gemm.py,8.82,9.82,9.08,9.17 +2,08-grouped-gemm.py,8.71,8.64,8.75,9.19 +3,08-grouped-gemm.py,8.78,9.10,8.79,9.20 +4,08-grouped-gemm.py,9.85,8.54,8.59,9.01 +5,08-grouped-gemm.py,8.56,8.86,8.70,9.21 diff --git a/kernel-microbench/benchmarks.sh b/kernel-microbench/benchmarks.sh new file mode 100755 index 0000000..22a87a4 --- /dev/null +++ b/kernel-microbench/benchmarks.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +# Runs each profiler N times, times each, and prints the results + +# Number of trials +NUM_TRIALS=5 + +# Create a timestamped results file +TIMESTAMP=$(date +%s) +RESULTS_FILE="results_${TIMESTAMP}.csv" +echo "run,script,baseline,proton,torch,nsys" > "$RESULTS_FILE" +# also include nn-*.py scripts in tk directory +for SCRIPT in [0-9]*-*.py tk/[0-9]*-*.py; do + + BASELINE_TIMES=() + PROTON_TIMES=() + TORCH_TIMES=() + NSYS_TIMES=() + + echo -e "\n===== Benchmarking $SCRIPT =====" + + cd "$(dirname "$0")" + # run once to warm up + python "$SCRIPT" 2>&1 >/dev/null + + echo "No Profiling for baseline" + for ((i=1; i<=NUM_TRIALS; i++)); do + t=$(/usr/bin/time -f "%e" python "$SCRIPT" 2>&1 >/dev/null) + sleep 3 + BASELINE_TIMES+=("$t") + echo "Run $i: $t s" + done + + echo -e "\nProfiling with --profiler proton..." + for ((i=1; i<=NUM_TRIALS; i++)); do + t=$(/usr/bin/time -f "%e" python "$SCRIPT" --profiler proton 2>&1 >/dev/null) + sleep 3 + PROTON_TIMES+=("$t") + echo "Run $i: $t s" + done + + echo -e "\nProfiling with --profiler torch..." + for ((i=1; i<=NUM_TRIALS; i++)); do + t=$(/usr/bin/time -f "%e" python "$SCRIPT" --profiler torch 2>&1 >/dev/null) + sleep 3 + TORCH_TIMES+=("$t") + echo "Run $i: $t s" + done + + echo -e "\nProfiling with nsys..." 
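+ # Note: GNU time's -f "%e" prints elapsed wall-clock seconds to stderr, so each timing command here captures stderr into $t while the benchmarked script's own stdout is discarded.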
+ for ((i=1; i<=NUM_TRIALS; i++)); do + t=$(/usr/bin/time -f "%e" bash -c "nsys profile --trace=cuda --sample=none --cpuctxsw=none python \"$SCRIPT\" > /dev/null" 2>&1) + sleep 3 + NSYS_TIMES+=("$t") + echo "Run $i: $t s" + done + # delete nsys-rep files + rm -f *.nsys-rep + + # Print results + + echo -e "\n==== Results (seconds) ====" + echo "baseline: [${BASELINE_TIMES[*]}]" + echo "proton: [${PROTON_TIMES[*]}]" + echo "torch: [${TORCH_TIMES[*]}]" + echo "nsys: [${NSYS_TIMES[*]}]" + # Append results to the global CSV file + for ((i=0; i<NUM_TRIALS; i++)); do + echo "$((i+1)),$SCRIPT,${BASELINE_TIMES[$i]},${PROTON_TIMES[$i]},${TORCH_TIMES[$i]},${NSYS_TIMES[$i]}" >> "$RESULTS_FILE" + done + echo -e "\n==== Results in CSV format ====" + echo "run,script,baseline,proton,torch,nsys" + for ((i=0; i> x_nbits) +# y = tl.sort(y, dim=1) +# y_indices = y >> x_nbits +# y_values = (y & ((1 << x_nbits) - 1)).to(x_utype).to(x_dtype, bitcast=True) +# y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype) + +# # write back +# offs_y_n = tl.arange(0, N_EXPTS_ACT) +# Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :] +# Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :] +# tl.store(Yv_ptrs, y_values, mask=mask_m) +# tl.store(Yi_ptrs, y_indices, mask=mask_m) + +# # pack into bitmatrix +# y_div = y_indices // 32 +# y_rem = y_indices % 32 +# loop_iterations = N_EXPTS_PAD // BLOCK_N +# for i in range(loop_iterations): +# offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32) +# y2 = tl.where(y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0) +# r = tl.reduce_or(y2, axis=1) +# BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] +# tl.store(BitsPtrs, r, mask=mask_m) + +# def topk(x, k, dim=1, return_bitmatrix=True): +# cdiv = lambda a, b: (a + b - 1) // b +# BLOCK_M = 8 +# BLOCK_N = 128 +# assert x.ndim == 2 +# assert x.shape[-1] < 32768 +# assert dim == 1 +# assert return_bitmatrix +# n_rows, n_cols = x.shape +# dev = x.device +# n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N +# n_cols_words = n_cols_pad // 32 +# # scratchpad tensors +# # NOTE: these are not returned +# y_vals = torch.empty((n_rows, k), dtype=x.dtype, device=dev) +# y_indx = torch.empty((n_rows, k), dtype=torch.int16, device=dev) +# bitmatrix = torch.empty((n_rows, n_cols_words), dtype=torch.uint32, device=dev) +# _topk[(cdiv(n_rows, BLOCK_M), )]( +# x, x.stride(0), # inputs +# y_vals, y_indx, y_vals.stride(0), # output [topk] +# bitmatrix, bitmatrix.stride(0), # output [bitmatrix] +# n_rows, n_cols, # shapes +# BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, # tunable parameter +# N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k, # constants +# ) +# return y_vals, y_indx, Bitmatrix(bitmatrix, [n_rows, n_cols]) + + + +def simple_benchmark_topk(): + sizes = [128 * i for i in range(2, 6)] + ks = [4, 8, 16] + device = "cuda" + for size in sizes: + for k in ks: + x = torch.randn((size, size), device=device, dtype=torch.float16) + for _ in range(500): + # Triton topk + topk(x, k) + # Optionally, compare to torch's topk for correctness or reference + # torch.topk(x, k, dim=1) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_topk"): + simple_benchmark_topk() + # prof.export_chrome_trace("topk_trace.json") + with open("topk_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total",
row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + proton.start(name="proton_topk", context="shadow", backend="cupti") + simple_benchmark_topk() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_topk() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/kernel-microbench/tk/09-compaction.py b/kernel-microbench/tk/09-compaction.py new file mode 100644 index 0000000..9b4ef2e --- /dev/null +++ b/kernel-microbench/tk/09-compaction.py @@ -0,0 +1,51 @@ +import torch +from triton_kernels.compaction import compaction +from triton_kernels import Bitmatrix + + +def simple_benchmark_compaction(): + sizes = [(2 ** i, 2**i) for i in range(8, 14)] + device = "cuda" + k = 16 # or another reasonable value for k + p = 0.5 # probability to keep an index + for n_tokens, n_cols in sizes: + for _ in range(100): + yi = torch.rand((n_tokens, n_cols), device=device).argsort(dim=-1) + yi = yi[:, :k].to(torch.int32) + yv = torch.randn((n_tokens, k), dtype=torch.bfloat16, device=device) + mask = torch.zeros((n_tokens, n_cols), dtype=torch.int32, device=device) + keep = (torch.rand(yi.shape, device=device) < p) + if keep.any(): + rows = torch.arange(yi.size(0), device=device).unsqueeze(1).expand_as(yi) + mask[rows[keep], yi[keep]] = 1 + chunks = mask.view(*mask.shape[:-1], -1, 32) + weights = (1 << torch.arange(32, dtype=torch.int32, device=device)) + bitmask = (chunks.int() * weights).sum(dim=-1) + compaction(yv, yi, bitmask) + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_compaction"): + simple_benchmark_compaction() + with open("compaction_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + import triton.profiler as proton + proton.start(name="proton_compaction", context="shadow", backend="cupti") + simple_benchmark_compaction() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_compaction() + +if __name__ == "__main__": + main() + + diff --git a/kernel-microbench/tk/10-swiglu.py b/kernel-microbench/tk/10-swiglu.py new file mode 100644 index 0000000..17b1e12 --- /dev/null +++ b/kernel-microbench/tk/10-swiglu.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass, field +from triton_kernels.numerics import InFlexData, OutFlexData +import torch +import triton +from triton_kernels.swiglu import swiglu + + + + +def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"): + # the reference implementation and the triton implementation do not tie-break experts the same way + randbits = [torch.randperm(n_expts_tot) for _ in range(n_tokens)] + x = [(-1)**i * ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(dtype)) for i, bits in enumerate(randbits)] + return torch.stack(x).to(device=device) + + +def alloc_rand(shape, device, dtype, requires_grad=True): + if dtype.itemsize == 1: + tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16)) + return tmp.to(dtype).requires_grad_(requires_grad) + return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + + 
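+# NOTE: the dataclasses below appear to be local stand-ins for the PrecisionConfig/FlexCtx interface used by triton_kernels.swiglu (cf. tk/bench/bench_mlp.py and tk/tests/test_swiglu.py, which construct PrecisionConfig(limit=...) directly); only the "limit" field is exercised by this benchmark.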
+@dataclass(frozen=True) +class FlexCtx: + out_data: OutFlexData = OutFlexData() + inp_data: InFlexData = InFlexData() + saturate_inf: bool = False + + +@dataclass(frozen=True) +class PrecisionConfig: + limit: float + flex_ctx: FlexCtx = FlexCtx() + + +def simple_benchmark_swiglu(): + device = "cuda" + n_expts_tot = 6 + n_expts_act = 2 + alpha = 0.5 + limit = 10 + sizes = [ + (256, 1024, 256), + (512, 2048, 512), + (1024, 4096, 1024), + (1311, 4352, 1311), + ] + for M, N, n_tokens in sizes: + torch.manual_seed(2) + logits = init_data(M, n_expts_tot).detach() + x = alloc_rand([n_tokens, N], device=device, dtype=torch.bfloat16) + precision_config = PrecisionConfig(limit=limit) + for _ in range(100): + swiglu(x, alpha, precision_config) + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_swiglu"): + simple_benchmark_swiglu() + with open("swiglu_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + import triton.profiler as proton + proton.start(name="proton_swiglu", context="shadow", backend="cupti") + simple_benchmark_swiglu() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + simple_benchmark_swiglu() + + + + +if __name__ == "__main__": + main() + + diff --git a/kernel-microbench/tk/11-routing.py b/kernel-microbench/tk/11-routing.py new file mode 100644 index 0000000..f3f3e04 --- /dev/null +++ b/kernel-microbench/tk/11-routing.py @@ -0,0 +1,46 @@ +import torch + +from triton_kernels.routing import routing + +def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"): + # the reference implementation and the triton implementation do not tie-break experts the same way + randbits = [torch.randperm(n_expts_tot) for _ in range(n_tokens)] + x = [(-1)**i * ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(dtype)) for i, bits in enumerate(randbits)] + return torch.stack(x).to(device=device) + + +def simple_benchmark_routing(device="cuda"): + n_tokens = 8192 + block_m = 128 + n_expts_tot, n_expts_act = 128, 4 + tri_logits = init_data(n_tokens, n_expts_tot, device=device).detach() + for i in range(500): + torch.manual_seed(i) + tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act) + # tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m) # noqa: F841 + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profiler", type=str, default="") + args = parser.parse_args() + if args.profiler == "torch": + print("Profiling with torch") + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("benchmark_routing"): + simple_benchmark_routing() + with open("routing_trace.json", "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000).__str__()) + elif args.profiler == "proton": + print("Profiling with proton") + import triton.profiler as proton + proton.start(name="proton_routing", context="shadow", backend="cupti") + simple_benchmark_routing() + proton.finalize() + else: + print("Profiling with nsys (no-op fallback)") + 
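+ # Default ("nsys") path: no in-process profiler is attached; the run is expected to be wrapped externally, e.g. by benchmarks.sh via "nsys profile --trace=cuda ... python tk/11-routing.py".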
simple_benchmark_routing() + +if __name__ == "__main__": + main() diff --git a/kernel-microbench/tk/bench/bench_mlp.py b/kernel-microbench/tk/bench/bench_mlp.py new file mode 100644 index 0000000..8130f47 --- /dev/null +++ b/kernel-microbench/tk/bench/bench_mlp.py @@ -0,0 +1,226 @@ +from pathlib import Path +import matplotlib.pyplot as plt +import json +import triton.profiler as proton +import torch +import triton_kernels.swiglu +from triton_kernels.numerics_details.mxfp import downcast_to_mxfp +from triton_kernels.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx +from triton_kernels.numerics import InFlexData +from triton_kernels.routing import routing +from triton_kernels.target_info import is_hip, get_cdna_version +from dataclasses import dataclass + +if torch.cuda.is_available() and not is_hip(): + from triton._C.libtriton import nvidia + cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) + cublas = nvidia.cublas.CublasLt(cublas_workspace) +else: + cublas = None + + +def _query_gpu_specs(): + import subprocess + if is_hip(): + cmd = ["rocm-smi", "--showproductname", "-d=0", "--csv"] + output = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode().strip() + model = output.splitlines()[1].split(",")[2] + if model in ["0x74a9", "0x74a1"]: + name = "AMD Instinct MI300X" + elif model == "0x74a5": + name = "AMD Instinct MI325X" + else: + name = "AMD" + else: + cmd = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader", "-i=0"] + output = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode().strip() + name = output.splitlines()[0] + + gpu_specs = { + "NVIDIA H100 80GB HBM3": {"MAX_TFLOPS8": 1979, "MAX_TFLOPS16": 989, "MAX_TBPS": 3.35}, + "HGX GB200": {"MAX_TFLOPS8": 4500, "MAX_TFLOPS16": 2250, "MAX_TBPS": 8.0}, + "AMD Instinct MI300X": {"MAX_TFLOPS8": 2615, "MAX_TFLOPS16": 1307, "MAX_TBPS": 5.3}, + "AMD Instinct MI325X": {"MAX_TFLOPS8": 2615, "MAX_TFLOPS16": 1307, "MAX_TBPS": 6.0}, + } + return gpu_specs.get(name) + + +SPECS = _query_gpu_specs() + + +def quantize(w, dtype, dev, **opt): + if dtype == "bf16": + wq = w.to(torch.bfloat16).transpose(-1, -2).contiguous().transpose(-1, -2) + return wq, InFlexData(), MicroscalingCtx() + elif dtype == "fp8": + fp8e4_dtype = torch.float8_e4m3fn if get_cdna_version() != 3 \ + else torch.float8_e4m3fnuz + wq = w.to(fp8e4_dtype).transpose(-1, -2).contiguous().transpose(-1, -2) + return wq, InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)), \ + MicroscalingCtx() + else: + assert dtype == "mx4", f"{dtype=}" + swizzle_mx_scale = opt["swizzle_mx_scale"] + swizzle_axis = 2 if swizzle_mx_scale else None + w = w.to(torch.bfloat16) + w, mx_scales, weight_scale_shape = downcast_to_mxfp(w, torch.uint8, axis=1, swizzle_axis=swizzle_axis) + return w, InFlexData(), MicroscalingCtx(weight_scale=mx_scales, swizzle_mx=swizzle_mx_scale, + actual_weight_scale_shape=weight_scale_shape) + + +@dataclass +class PerfData: + time: float + flops: float + bytes: float + + @property + def tflops(self): + return self.flops / self.time * 1e-3 + + @property + def tbps(self): + return self.bytes / self.time * 1e-3 + + @property + def opint(self): + # operational intensity + assert self.bytes > 0 + return self.flops / self.bytes + + @property + def util(self) -> float: + if SPECS is None: + return 0.0 + + peak_flops = max(SPECS["MAX_TFLOPS8"], SPECS.get("MAX_TFLOPS16", 0)) + min_t_flop = self.flops / peak_flops * 1e-3 # ns → µs + min_t_bw = self.bytes / SPECS["MAX_TBPS"] * 1e-3 + return 
max(min_t_flop, min_t_bw) / self.time + + +def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name): + assert n_expts_tot % EP == 0 + assert dim2 % TP == 0 + dev = "cuda" + + # input + # weights + wg = torch.randn((dim1, n_expts_tot), device=dev) + w1 = torch.randn((n_expts_tot // EP, dim1, dim2 // TP), device=dev) + w2 = torch.randn((n_expts_tot // EP, dim2 // TP // 2, dim1), device=dev) + # biases + bg = torch.randn((n_expts_tot, ), device=dev) + b1 = torch.randn((n_expts_tot // EP, dim2 // TP), device=dev) + b2 = torch.randn((n_expts_tot // EP, dim1), device=dev) + + # -- numerics -- + optg = dict() + opt1 = {"swizzle_mx_scale": True} if w_dtype == "mx4" else dict() + opt2 = {"swizzle_mx_scale": True} if w_dtype == "mx4" else dict() + wg, wg_flex, wg_mx = quantize(wg, "bf16", dev, **optg) + w1, w1_flex, w1_mx = quantize(w1, w_dtype, dev, **opt1) + w2, w2_flex, w2_mx = quantize(w2, w_dtype, dev, **opt2) + pcg = PrecisionConfig(mx_ctx=wg_mx, flex_ctx=FlexCtx(rhs_data=wg_flex)) + pcs = triton_kernels.swiglu.PrecisionConfig(limit=1.0) + pc1 = PrecisionConfig(mx_ctx=w1_mx, flex_ctx=FlexCtx(rhs_data=w1_flex)) + pc2 = PrecisionConfig(mx_ctx=w2_mx, flex_ctx=FlexCtx(rhs_data=w2_flex)) + + # -- benchmark -- + fpath = Path(f"logs/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/profiles/batch-{batch}.hatchet") + fpath.parent.mkdir(parents=True, exist_ok=True) + x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype] + # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz + if x_dtype == torch.float8_e4m3fn and get_cdna_version() == 3: + x_dtype = torch.float8_e4m3fnuz + + x = torch.randn((batch, dim1), device=dev) + xg = x.to(wg.dtype if n_expts_tot > 1 else x_dtype) + x = x.to(x_dtype) + # run layer + proton.start(str(fpath.with_suffix('')), hook="triton") + for i in range(100): + if n_expts_tot > 1: + logits = matmul_ogs(xg, wg, bg, precision_config=pcg) + rdata, gather_indx, scatter_indx = routing(logits, n_expts_act, simulated_ep=EP) + else: + rdata, gather_indx, scatter_indx = None, None, None + x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1) + x = triton_kernels.swiglu.swiglu(x, 1.0, pcs, routing_data=rdata) + x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2) + proton.finalize() + + # -- analyze -- + with open(f"{fpath}") as fd: + data = json.load(fd) + # TODO: this will be broken if kernels use scopes themselves + # compute useful (a.k.a. matmul) bytes and flops + matmuls = [ + x for x in data[0]["children"] if "_matmul" in x["frame"]["name"] and "metadata" not in x["frame"]["name"] + ] + bytes = sum([x["metrics"]["bytes"] for x in matmuls]) + flops = {w: sum([x["metrics"].get(f"flops{w}", 0) for x in matmuls]) for w in [8, 16]} + flops = sum([flops[w] for w in [8, 16]]) + # compute total time (incl. 
"not useful" work) + # TODO: proton should really be recording that in the json instead of + # relying on the user to aggregate + time = sum(x["metrics"].get("time (ns)", 0) for x in data[0]["children"]) + return PerfData(time, flops, bytes) + + +def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP=1, EP=1, name="", + verbose=True): + from itertools import chain + from bisect import bisect_left + batches = list(chain(*[range(*r) for r in batch_ranges])) + # collect performance data + perfs = [] + print(f"Benchmarking {name} ({x_dtype}x{w_dtype}, TP={TP}, EP={EP})...") + print("===============================================================") + for batch in batches: + perfs += [bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name)] + if verbose: + print(f"Batch: {batch}; Util: {perfs[-1].util}; TFLOPS: {perfs[-1].tflops}; TBPS: {perfs[-1].tbps}") + print("===============================================================") + # machine limits + fig, ax = plt.subplots(figsize=(7, 5), dpi=120) + ax.set_xlabel("batch size (toks/expt)") + ax.set_ylabel("performance [TFLOP/s]") + ax.set_title("roofline") + # add a tiny margin so points are not flush with the frame + xs = [batch * n_expts_act / n_expts_tot for batch in batches] + perf = [p.tflops for p in perfs] + xmin, xmax = min(xs), max(xs) + dx = 0.05 * (xmax - xmin) if xmax > xmin else 1.0 + ax.set_xlim(xmin - dx, xmax + dx) + ax.set_ylim(100, SPECS["MAX_TFLOPS8"] + 500) + # plot roofline + max_tbps = SPECS["MAX_TBPS"] + max_tflops = SPECS["MAX_TFLOPS8"] + opints = [p.opint for p in perfs] + knee = bisect_left(opints, max_tflops / max_tbps) - 1 + x_bw, x_comp = xs[:knee], xs[knee:] + y_bw = [op * max_tbps for op in opints[:knee]] + y_comp = [max_tflops] * len(x_comp) + ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.0f} TB/s)") + ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)") + # plot data + ax.scatter(xs, perf, marker="+") + ax.legend(frameon=False, loc="lower right") + ax.grid(True, which="both", ls=":", lw=0.5) + fig.tight_layout() + fpath = Path(f"logs/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/roofline.png") + plt.savefig(fpath) + + +if __name__ == "__main__": + has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4 + if SPECS is None: + print("Current GPU has no specs provided, utilization is N/A") + batch_ranges = [(1024, 32768, 1024)] + dense_dtypes = ["fp8", "fp8"] + quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"] + roofline_mlp(batch_ranges, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense") + roofline_mlp(batch_ranges, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense") + roofline_mlp(batch_ranges, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick") + roofline_mlp(batch_ranges, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick") diff --git a/kernel-microbench/tk/pyproject.toml b/kernel-microbench/tk/pyproject.toml new file mode 100644 index 0000000..c7a45ae --- /dev/null +++ b/kernel-microbench/tk/pyproject.toml @@ -0,0 +1,11 @@ +[project] +name = "triton_kernels" +version = "1.0.0" +dependencies = ["torch", "numpy", "pytest"] + +[build-system] +requires = ["setuptools>=64.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +include = ["triton_kernels*"] diff --git a/kernel-microbench/tk/tests/__init__.py b/kernel-microbench/tk/tests/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/kernel-microbench/tk/tests/conftest.py b/kernel-microbench/tk/tests/conftest.py new file mode 100644 index 0000000..90767c8 --- /dev/null +++ b/kernel-microbench/tk/tests/conftest.py @@ -0,0 +1,10 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default="cuda") + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") diff --git a/kernel-microbench/tk/tests/test_compaction.py b/kernel-microbench/tk/tests/test_compaction.py new file mode 100644 index 0000000..4e6c31e --- /dev/null +++ b/kernel-microbench/tk/tests/test_compaction.py @@ -0,0 +1,28 @@ +import pytest +import torch +from triton_kernels.compaction import compaction, compaction_torch + + +@pytest.mark.parametrize("n_tokens, n_cols, k, p", [ + (8192, 64, 4, 0.5), + (8192, 64, 4, 1.0), + (131, 128, 16, 0.6), + (496, 128, 16, 0.), +]) +def test_compaction(n_tokens, n_cols, k, p, device): + yi = torch.rand((n_tokens, n_cols), device=device).argsort(dim=-1) + yi = yi[:, :k].to(torch.int32) + yv = torch.randn((n_tokens, k), dtype=torch.bfloat16, device=device) + # "drop" indices from yi with probability `p` + mask = torch.zeros((n_tokens, n_cols), dtype=torch.int32, device=device) + keep = (torch.rand(yi.shape, device=device) < p) + if keep.any(): + rows = torch.arange(yi.size(0), device=device).unsqueeze(1).expand_as(yi) + mask[rows[keep], yi[keep]] = 1 + chunks = mask.view(*mask.shape[:-1], -1, 32) + weights = (1 << torch.arange(32, dtype=torch.int32, device=device)) + bitmask = (chunks.int() * weights).sum(dim=-1) + yv_ref, yi_ref = compaction_torch(yv, yi, bitmask) + yv_tri, yi_tri = compaction(yv, yi, bitmask) + assert torch.all(yi_ref == yi_tri) + assert torch.all(yv_ref == yv_tri) diff --git a/kernel-microbench/tk/tests/test_matmul.py b/kernel-microbench/tk/tests/test_matmul.py new file mode 100644 index 0000000..ee3d0b1 --- /dev/null +++ b/kernel-microbench/tk/tests/test_matmul.py @@ -0,0 +1,321 @@ +from dataclasses import dataclass, fields +import pytest +import torch +from typing import Union +# benchmarking utilities +# routing utilities +from triton_kernels.routing import routing +# matmul utilities +import triton_kernels.matmul_ogs_details.opt_flags as opt_flags +from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, MicroscalingCtx +from triton_kernels.matmul_ogs import can_use_persistent_tma +from triton_kernels.matmul_ogs import matmul_ogs, matmul_ogs_torch +# numerics utilities +from triton_kernels.numerics import InFlexData, OutFlexData +from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp +# testing utilities +from triton_kernels.testing import assert_close, compute_actual_scale +# target-specific utilities +from triton_kernels.target_info import is_hip + +# --------------- +# initialize data +# --------------- + + +def alloc_rand(shape, device, dtype, requires_grad=True): + if dtype.itemsize == 1: + tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16)) + return tmp.to(dtype).requires_grad_(requires_grad) + return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + +def alloc_rand_like(x): + return alloc_rand(x.shape, x.device, x.dtype, x.requires_grad) + + +def mask_indx(idx, n_expts_act): + idx.src_indx[idx.dst_indx[-n_expts_act:]] = -1 + idx.dst_indx[-n_expts_act:] = -1 + return idx + + +def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, device="cuda"): + logits = torch.randn((m, 
n_expts_tot), dtype=torch.float16, device=device, requires_grad=True) + routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act, simulated_ep=n_expt_shards) + routing_data.gate_scal = None + gather_idx = gather_idx if do_gather else None + scatter_idx = scatter_idx if do_scatter else None + return m, routing_data, gather_idx, scatter_idx + + +def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype, + has_y_gammas, requires_grad=True, device="cuda"): + torch.manual_seed(0) + assert mode in {'batched', 'ragged'} + in_m = m * (n_expts_act if gindx is None else 1) + shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k) + x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad) + w = alloc_rand((n_expts_tot // n_expt_shards, k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad) + bias = alloc_rand((n_expts_tot // n_expt_shards, n), device=device, dtype=torch.float32, + requires_grad=requires_grad) + gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad) + gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad) + gs0 = gs0.detach().requires_grad_(requires_grad) + gs1 = gs1.detach().requires_grad_(requires_grad) + if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2): + gs0 = None + gs1 = None + return x, w, bias, gs0, gs1 + + +# --------------- +# numerics stuff +# --------------- + + +def init_precision(out_dtype, act_use_flexpoint, weight_use_flexpoint, n_expts_tot=1, mx_ctx=MicroscalingCtx(), + device="cuda"): + # flexpoint + make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) + + ([val0] + if n_expts_tot % 2 else []), dtype=torch.float32, device=device) + make_scalar = lambda val: torch.tensor([val], dtype=torch.float32, device=device) + in_flex_data = lambda scale, use_flex: InFlexData(dtype=torch.float8_e5m2, scale=make_scalar(scale) + ) if use_flex else InFlexData() + in_flex_edata = lambda scale0, scale1, use_flex: InFlexData(dtype=torch.float8_e5m2, scale=make_tensor( + scale0, scale1)) if use_flex else InFlexData() + out_flex_data = lambda scale, use_flex: OutFlexData(dtype=torch.float8_e5m2, expected_scale=make_scalar( + scale), actual_scale=make_scalar(0), checksum_scale=make_scalar(0)) if use_flex else OutFlexData() + flex_ctx = FlexCtx( + lhs_data=in_flex_data(1.25, act_use_flexpoint), + rhs_data=in_flex_edata(1.50, 1.25, weight_use_flexpoint), + out_data=out_flex_data(4.00, act_use_flexpoint), + ) + return PrecisionConfig(flex_ctx=flex_ctx, acc_scale=2.0 if act_use_flexpoint or weight_use_flexpoint else 1.0, + mx_ctx=mx_ctx, out_dtype=out_dtype) + + +def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config): + flex_ctx = precision_config.flex_ctx + + def apply(x, scale): + if scale is None: + return x.clone().detach().requires_grad_(True) + elif scale.numel() == 1: + return (x.float() * scale).detach().requires_grad_(True) + else: + assert x.ndim == 3 + assert scale.numel() == x.shape[0] + return (x.float() * scale[:, None, None]).detach().requires_grad_(True) + + return ( + apply(x_tri, flex_ctx.lhs_data.scale), + apply(w_tri, flex_ctx.rhs_data.scale), + apply(bias_tri, None), + None if gs0_tri is None else apply(gs0_tri, None), + None if gs1_tri is None else apply(gs1_tri, None), + ) + + +def dtype_str_to_torch(dtype_str: 
str) -> torch.dtype: + return torch.uint8 if dtype_str == "float4_e2m1" else getattr(torch, dtype_str) + + +# --------------- +# unit tests +# --------------- + + +@dataclass +class Case: + m: int + n: int + k: int + mode: str + act_dtype_str: str + weight_dtype_str: str + n_expts_tot: int = 1 + n_expts_act: int = 1 + n_expt_shards: int = 1 + split_k: int = 1 + swizzle_mx_scale: bool = False + epilogue_subtile: Union[bool, None] = None + + +@pytest.mark.parametrize( + ", ".join(f.name for f in fields(Case)), + [ + tuple(getattr(case, f.name) for f in fields(Case)) for case in [ + # Non-mx types: + Case(16, 256, 256, "ragged", "float16", "float16", 128, 4), + Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2), + Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=4), + Case(16, 256, 256, "ragged", "float16", "float16", 4, 1, n_expt_shards=2), + Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3), + Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3), + Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1), + Case(16, 256, 256, "batched", "float16", "float16", 5, 1), + Case(16, 256, 256, "ragged", "float16", "float16", 3, 1), + Case(256, 256, 256, "ragged", "float16", "float16", 4, 1), + Case(256, 256, 256, "ragged", "float16", "float16", 4, 1, split_k=3), + Case(300, 400, 400, "batched", "float16", "float16", 5, 1), + Case(300, 400, 400, "ragged", "float16", "float16"), + Case(300, 400, 400, "ragged", "float8_e5m2", "float8_e5m2"), + Case(1000, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 3, 1), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=False), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=True), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, n_expt_shards=2), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 1, n_expt_shards=2), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2), + Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1), + Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2), + Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9), + # mx types: + Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4), + Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), + Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9), + Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4), + Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, swizzle_mx_scale=True), + Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4), + Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), + Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, swizzle_mx_scale=True), + Case(1000, 704, 800, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, swizzle_mx_scale=True), + Case(1000, 704, 800, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, swizzle_mx_scale=False), + Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, swizzle_mx_scale=False), + Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, swizzle_mx_scale=True), + Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, swizzle_mx_scale=False), + Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, swizzle_mx_scale=True), + 
Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, swizzle_mx_scale=False), + Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, swizzle_mx_scale=True), + Case(300, 400, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, swizzle_mx_scale=False), + Case(300, 400, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, swizzle_mx_scale=True), + Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, swizzle_mx_scale=False), + Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, swizzle_mx_scale=True), + # AMD + Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"), + Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1), + Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2), + Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, n_expt_shards=2), + Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2), + Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2), + ] + ], +) +@pytest.mark.parametrize("block_m", [16, 128]) +@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [ + (False, False, False), + (True, False, False), + (False, True, False), + (True, True, False), + (True, True, True), +]) +@pytest.mark.parametrize("has_y_gammas", [False, True]) +@pytest.mark.parametrize("is_persistent", [False, True]) +def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot, + n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, swizzle_mx_scale, + epilogue_subtile, device): + # TODO: remove when Triton FP8 supports proper RTNE + if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("Float8 not tested on A100") + if "float8_e4m3fnuz" in weight_dtype_str and not is_hip(): + pytest.skip("float8_e4m3fnuz only tested on HIP platforms") + if "mx" in weight_dtype_str and is_hip(): + pytest.skip("mxfloat* only tested on CUDA platforms") + if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10: + pytest.skip("float16 x mx not supported with cuda capability >= 10") + if "float8" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 10: + pytest.skip("float8 x mx not supported with cuda capability < 10") + if fused_scatter and split_k > 1: + pytest.skip("fused scatter scratchpad not supported with split_k") + + torch.manual_seed(0) + + block_k = None + if is_persistent and weight_dtype_str.startswith("mx") and not torch.cuda.get_device_capability()[0] >= 10: + # Override block_k for testing correctness. The default is temporarily 128 for + # performance reasons which doesn't work with persistent matmul. 
+ # TODO: revisit when Triton is better for H100 + MXFP4 + block_k = 256 + + constraints = { + "block_m": block_m, + "block_k": block_k, + "split_k": split_k, + "fused_scatter": fused_scatter, + "is_persistent": is_persistent, + "epilogue_subtile": epilogue_subtile, + } + opt_flags.update_opt_flags_constraints(constraints) + + is_mixed_input = act_dtype_str != weight_dtype_str + if weight_dtype_str.startswith("mx"): + weight_dtype_str = weight_dtype_str[2:] + + test_bwd = False + weight_dtype = dtype_str_to_torch(weight_dtype_str) + act_dtype = dtype_str_to_torch(act_dtype_str) + act_is_float8 = act_dtype.itemsize == 1 + weight_is_float8 = weight_dtype.itemsize == 1 + precision_opt = init_precision(act_dtype, act_is_float8, weight_is_float8 and not is_mixed_input, + n_expts_tot // n_expt_shards, device=device) + # precision_opt.x_pad_trans_requires_flexpoint = False + if mode == "ragged": + m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, + device=device) + else: + rdata = gindx = sindx = None + x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, + n_expt_shards, mode, act_dtype, # + torch.bfloat16 if is_mixed_input else weight_dtype, + has_y_gammas, requires_grad=test_bwd, device=device) + x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt) + + if is_mixed_input: + swizzle_axis = 2 if swizzle_mx_scale else None + w_tri, mx_scales_tri, weight_scale_shape = downcast_to_mxfp(w_tri, weight_dtype, axis=1, + swizzle_axis=swizzle_axis) + w_ref = upcast_from_mxfp(w_tri, mx_scales_tri, torch.bfloat16, axis=1, swizzle_axis=swizzle_axis) + + precision_opt.mx_ctx = MicroscalingCtx(weight_scale=mx_scales_tri, swizzle_mx=swizzle_mx_scale, + actual_weight_scale_shape=weight_scale_shape) + + if is_persistent and not can_use_persistent_tma(x_tri, w_tri, gindx, precision_opt): + pytest.skip("persistent TMAs not supported for this test") + + if w_tri.shape[0] == 1: + # Test the case when weight has dim 2, i.e., shape (K, N). + w_tri = w_tri.squeeze(0).detach().requires_grad_() + w_ref = w_ref.squeeze(0).detach().requires_grad_() + + if mode == "batched": + rdata, gindx, sindx = None, None, None + flex = precision_opt.flex_ctx + # triton + tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref) + # If split_k > 1, then the intermediate tensor is fp32. 
+ sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1 + sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1 + y_scale = flex.out_data.expected_scale if act_is_float8 else 1 + + def round_x(x, idx): + return x.to(act_dtype).to(torch.float32) if sep_gather else x + + round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y + ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref, # + rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref) + scale = lambda val, scal: val if scal is None else val / scal + if n_expt_shards > 1: + if not do_scatter: + n_rows = rdata.expt_hist.sum() + assert n_rows > 0 + ref_y = ref_y[:n_rows] + tri_y = tri_y[:n_rows] + assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y) + + if act_is_float8: + tri_y_scale = flex.out_data.actual_scale.clone() + ref_y_scale = compute_actual_scale(ref_y, tri_y.dtype) + assert (ref_y_scale - + tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}" diff --git a/kernel-microbench/tk/tests/test_routing.py b/kernel-microbench/tk/tests/test_routing.py new file mode 100644 index 0000000..41258a6 --- /dev/null +++ b/kernel-microbench/tk/tests/test_routing.py @@ -0,0 +1,93 @@ +import pytest +import torch +from triton_kernels.routing import routing, routing_torch +from triton_kernels.testing import assert_close +from triton_kernels.matmul_ogs_details.metadata import compute_metadata +from triton_kernels.testing import assert_equal + + +def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"): + # the reference implementation and the triton implementation do not tie-break experts the same way + randbits = [torch.randperm(n_expts_tot) for _ in range(n_tokens)] + x = [(-1)**i * ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(dtype)) for i, bits in enumerate(randbits)] + return torch.stack(x).to(device=device) + + +def ref_expt_data(routing_data, n_gates, block_m): + hist = routing_data.expt_hist + n_expts_tot = routing_data.n_expts_tot + blks = (hist + block_m - 1) // block_m # matmul blocks needed + tsum = torch.cumsum(hist, dim=0) # prefix sum of tokens + bsum = torch.cumsum(blks, dim=0) # prefix sum of blocks + # Get the max number of matmul blocks of size d_tile needed (and is launched with). + # This assumes the worst distribution of all experts with one token except for one that has the rest. 
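+ # For example (illustrative numbers): with n_expts_tot=4, n_gates=10, block_m=4, the worst case is three experts holding one token (one block each) and one expert holding 7 tokens (two blocks), so grid_m = ceil_div(7, 4) + 3 = 5, matching the formula in the else branch below.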
+ if n_gates <= n_expts_tot: + grid_m = n_gates + else: + # ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1 + # ceil_div(x, y): -(-x // y) + grid_m = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // block_m) + bloc_data = -torch.ones(grid_m, dtype=torch.int32) + # compute data required to drive ragged batch matmul + for e in range(n_expts_tot): + offset = bsum[e - 1] if e else 0 + for b in range(blks[e]): + bloc_data[offset + b] = (b << 16) + e + + expt_data = torch.zeros(n_expts_tot * 3 + 2 + grid_m, dtype=torch.int32, device=hist.device) + expt_data[:n_expts_tot] = routing_data.expt_hist + expt_data[n_expts_tot + 1:n_expts_tot * 2 + 1] = tsum + expt_data[n_expts_tot * 2 + 2:n_expts_tot * 3 + 2] = bsum + expt_data[n_expts_tot * 3 + 2:] = bloc_data + return expt_data + + +@pytest.mark.parametrize("n_tokens", [371, 255, 256, 8192, 1023, 1024]) +@pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 4), (1500, 8)]) +@pytest.mark.parametrize("block_m", [64, 128]) +def test_op(n_tokens, n_expts_tot, n_expts_act, block_m, device): + torch.manual_seed(2) + tri_logits = init_data(n_tokens, n_expts_tot, device=device).detach() + ref_logits = tri_logits.clone() + ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act) + tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act) + ref_metadata = ref_expt_data(ref_routing_data, n_tokens * n_expts_act, block_m) + tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m).buffer + + assert_close(ref_routing_data.gate_scal, tri_routing_data.gate_scal, 2e-2, 4e-3) + assert_equal(ref_routing_data.expt_hist, tri_routing_data.expt_hist) + assert_equal(ref_metadata, tri_metadata) + assert ref_routing_data.n_expts_tot == ref_routing_data.n_expts_tot + assert ref_routing_data.n_expts_act == ref_routing_data.n_expts_act + + def _assert_indx_equal(ref, tri): + assert_equal(ref, tri[:len(ref)]) + assert torch.all(tri[len(ref):] == -1) + + _assert_indx_equal(ref_gather.src_indx, tri_gather.src_indx) + _assert_indx_equal(ref_gather.dst_indx, tri_gather.dst_indx) + _assert_indx_equal(ref_scatter.src_indx, tri_scatter.src_indx) + _assert_indx_equal(ref_scatter.dst_indx, tri_scatter.dst_indx) + + +def bench_routing(): + import triton.profiler as proton + n_tokens = 8192 + block_m = 128 + n_expts_tot, n_expts_act = 128, 4 + tri_logits = init_data(n_tokens, n_expts_tot) + proton.start("routing") + proton.activate() + for i in range(100): + tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act) + tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m) # noqa: F841 + proton.finalize() + try: + import os + os.system("proton-viewer -m time/ms routing.hatchet") + except Exception: + pass + + +if __name__ == "__main__": + bench_routing() diff --git a/kernel-microbench/tk/tests/test_swiglu.py b/kernel-microbench/tk/tests/test_swiglu.py new file mode 100644 index 0000000..b3ab353 --- /dev/null +++ b/kernel-microbench/tk/tests/test_swiglu.py @@ -0,0 +1,42 @@ +from triton_kernels.routing import routing_torch +from triton_kernels.swiglu import swiglu, swiglu_torch, PrecisionConfig +from triton_kernels.testing import assert_close +import torch +import pytest + +from .test_routing import init_data as init_routing_data + +# --------------- +# initialize data +# --------------- + + +def alloc_rand(shape, device, dtype, requires_grad=True): + if dtype.itemsize == 1: + tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16)) + return 
tmp.to(dtype).requires_grad_(requires_grad) + return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + +# --------------- +# unit tests +# --------------- + + +@pytest.mark.parametrize("M, N", [(1311, 4352)]) +@pytest.mark.parametrize("limit", [1e-2, 10]) +def test_op(M, N, limit, device, alpha=0.5): + torch.manual_seed(2) + # initialize expert data + n_expts_tot = 6 + n_expts_act = 2 + logits = init_routing_data(M, n_expts_tot).detach() + routing_data, _, _ = routing_torch(logits, n_expts_act) + n_tokens = routing_data.expt_hist.sum() + + # initialize data + x = alloc_rand([n_tokens, N], device=device, dtype=torch.bfloat16) + precision_config = PrecisionConfig(limit=limit) + tri_y = swiglu(x, alpha, precision_config, routing_data) + ref_y = swiglu_torch(x, alpha, precision_config) + assert_close(tri_y, ref_y) diff --git a/kernel-microbench/tk/triton_kernels/__init__.py b/kernel-microbench/tk/triton_kernels/__init__.py new file mode 100644 index 0000000..2e108d2 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/__init__.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass + + +@dataclass +class Bitmatrix: + data: "torch.Tensor" # noqa: F821 + shape: tuple[int] diff --git a/kernel-microbench/tk/triton_kernels/compaction.py b/kernel-microbench/tk/triton_kernels/compaction.py new file mode 100644 index 0000000..b4849d1 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/compaction.py @@ -0,0 +1,69 @@ +import torch +from .compaction_details._masked_compaction import _masked_compaction +from triton_kernels import Bitmatrix + + +def compaction(yv, yi, bitmask, sentinel=-1): + """ + Return compacted copies of *yv* and *yi* based on a per-row bitmask. + + Only the elements whose index appears among the active bits of *bitmask* + are kept; the rest are replaced by *sentinel*. Kept elements preserve + their original left-to-right order. + + Parameters + ---------- + yv : torch.Tensor, shape (B, K) + Values tensor. + yi : torch.Tensor, shape (B, K), dtype torch.long + Integer indices (0 ≤ index < 32) associated with *yv*. + bitmask : torch.Tensor, shape (B,) **or** (B, 32) + Per-row mask of active indices. See the in-place version for details. + sentinel : int, default -1 + Value written into dropped positions of the returned tensors. + + Returns + ------- + (yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K) + New tensors with the same dtype/device as the inputs. 
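+
+    Example
+    -------
+    Illustrative values (not taken from the accompanying tests): with
+    yv = [[10., 20., 30., 40.]], yi = [[0, 1, 2, 3]] and a bitmask row whose
+    32-bit word has bits 0 and 2 set (value 0b101), the elements with indices
+    0 and 2 are kept and packed to the front, giving
+    yv_out = [[10., 30., -1., -1.]] and yi_out = [[0, 2, -1, -1]].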
+ + """ + + n_rows, n_cols = yi.shape + ret_yv = torch.empty_like(yv) + ret_yi = torch.empty_like(yi) + if isinstance(bitmask, Bitmatrix): + bitmask = bitmask.data + + _masked_compaction[(n_rows, )]( + yv, yi, bitmask, bitmask.stride(0), # inputs + ret_yv, ret_yi, # outputs + sentinel, # sentinel + K=n_cols # constants + ) + return ret_yv, ret_yi + + +def compaction_torch(yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1): + """ + reference implementation of `masked_compact` + """ + B, K = yi.shape + device = yi.device + # Expand bitmask to a boolean matrix of active bits (B, 32) + w = (1 << torch.arange(32, device=device, dtype=bitmask.dtype)) + bits = (bitmask.unsqueeze(-1) & w) != 0 + mask = bits.flatten(start_dim=-2) # or bits.reshape(B, -1) + # For every yi element decide whether it should be kept + keep = mask.gather(1, yi.long()) + # Build a stable permutation that brings all "keep" items forward + # False→0, True→1 ==> invert so kept==0, dropped==1, then argsort + order = (~keep).to(torch.int).argsort(dim=1, stable=True) + # Re‑order tensors according to above permutation + yi_sorted = yi.gather(1, order) + yv_sorted = yv.gather(1, order) + # fill relevant positions with sentinel + keep_sorted = keep.gather(1, order) + yi_sorted[~keep_sorted] = sentinel + yv_sorted[~keep_sorted] = sentinel + return yv_sorted, yi_sorted diff --git a/kernel-microbench/tk/triton_kernels/compaction_details/_masked_compaction.py b/kernel-microbench/tk/triton_kernels/compaction_details/_masked_compaction.py new file mode 100644 index 0000000..2d83fa0 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/compaction_details/_masked_compaction.py @@ -0,0 +1,19 @@ +import triton +import triton.language as tl + + +@triton.jit +def _masked_compaction(Yv, Yi, BitMask, stride_bm, RetYv, RetYi, sentinel, K: tl.constexpr): + pid_m = tl.program_id(0) + yv = tl.load(Yv + pid_m * K + tl.arange(0, K)) + yi = tl.load(Yi + pid_m * K + tl.arange(0, K)) + div = yi // 32 + rem = yi % 32 + active_bits = (tl.load(BitMask + pid_m * stride_bm + div) >> rem) & 1 + exc_cumsum = tl.cumsum(active_bits, 0) - active_bits + rev_arange = tl.where(active_bits, 0, K - 1 - tl.arange(0, K)) + write_indx = exc_cumsum + rev_arange + yv = tl.where(active_bits, yv, sentinel) + yi = tl.where(active_bits, yi, sentinel) + tl.store(RetYv + pid_m * K + write_indx, yv) + tl.store(RetYi + pid_m * K + write_indx, yi) diff --git a/kernel-microbench/tk/triton_kernels/matmul_ogs.py b/kernel-microbench/tk/triton_kernels/matmul_ogs.py new file mode 100644 index 0000000..72a9d8a --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/matmul_ogs.py @@ -0,0 +1,679 @@ +from dataclasses import dataclass +import itertools +import sys +import torch +import triton +# utilities +from triton_kernels import target_info +from triton_kernels.numerics import InFlexData, OutFlexData +from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx +# details +from .matmul_ogs_details._matmul_ogs import _compute_writeback_idx +from .matmul_ogs_details._matmul_ogs import _matmul_ogs +from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn +from .matmul_ogs_details._finalize_matmul import _finalize_matmul +from .matmul_ogs_details.opt_flags import make_opt_flags +from .matmul_ogs_details.metadata import compute_metadata +from .matmul_ogs_details.fast_contiguous import fast_contiguous +from .specialize import specialize + + +@dataclass +class Epilogue: + name: str + fn: "triton.runtime.jit.JITFunction" + 
fn_arg_names: tuple[str] + fn_arg_values_matmul: tuple[object] + fn_arg_values_finalize: tuple[object] + fn_arg_do_not_specialize: tuple[str] = tuple() + is_expensive: bool = False + + +_kernels = dict() + + +def get_kernels(epilogue: Epilogue): + global _kernels + if epilogue.name in _kernels: + return _kernels[epilogue.name] + spec_constants = {"EPILOGUE_FN": epilogue.fn} + spec_tuples = {"epilogue_fn_args": epilogue.fn_arg_names} + do_not_specialize = epilogue.fn_arg_do_not_specialize + import types + + module = types.ModuleType(f"matmul_ogs_{epilogue.name}") + sys.modules[module.__name__] = module + module._finalize_matmul = specialize(_finalize_matmul, module, spec_constants, spec_tuples, + do_not_specialize=do_not_specialize) + module._matmul_ogs = specialize(_matmul_ogs, module, spec_constants, spec_tuples, + do_not_specialize=do_not_specialize) + module._p_matmul_ogs = specialize(_p_matmul_ogs, module, spec_constants, spec_tuples, + do_not_specialize=do_not_specialize) + _kernels[epilogue.name] = module + return module + + +# ----------------------------------------------------------------------------- +# Matrix Multiplication + Outer Gather/Scatter +# ----------------------------------------------------------------------------- + + +def can_overflow_int32(tensor: torch.Tensor): + max_int32 = (1 << 31) - 1 + offset = 0 + for i in range(tensor.ndim): + offset += (tensor.shape[i] - 1) * tensor.stride(i) + return offset > max_int32 + + +def should_upcast_indices(*args): + return any(tensor is not None and can_overflow_int32(tensor) for tensor in args) + + +# --------------------- +# Numerics +# --------------------- + +# fmt: off + +@dataclass(frozen=True) +class MicroscalingCtx: + # This interprets the scales as E8M0 tensors + # Packed fp4s (e2m1) are stored as torch.uint8 tensors. + # Not used for now, inserted here to make space in the APIs. + act_scale: torch.Tensor | None = None + weight_scale: torch.Tensor | None = None + + swizzle_mx: bool = False # Whether the weight scales are stored in swizzled 5D layout + actual_weight_scale_shape: tuple | None = None # Actual weight scales shape, without padding + + def __post_init__(self): + assert self.act_scale is None, "Activation scale not supported yet" + if self.weight_scale is None: + return + + if self.actual_weight_scale_shape is None: + object.__setattr__(self, "actual_weight_scale_shape", self.weight_scale.shape) + + # Validate the scale tensor data type + if self.weight_scale.dtype != torch.uint8: + raise TypeError(f"Weight scale must be uint8. Got {self.weight_scale.dtype}") + + # Validate scale tensor dimensions + if self.weight_scale.ndim != 3: + raise ValueError( + f"Weight scale must be 3D (experts, in_dim // BLOCK_SIZE, out_dim). Got {self.weight_scale.shape}" + ) + + def check_inputs(self, weights: torch.Tensor) -> None: + if self.weight_scale is None: + return + + valid_weight_types = {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn} + # Validate weights data type + if weights.dtype not in valid_weight_types: + raise TypeError(f"Weights must be one of {valid_weight_types}. Got {weights.dtype}") + + # Validate weights tensor dimensions + if weights.ndim != 3: + raise ValueError(f"Weights must be 3D (experts, in_dim, out_dim). Got {weights.shape}") + + # Validate shapes + weight_scale_shape = self.actual_weight_scale_shape + if weights.shape[0] != weight_scale_shape[0] or weights.shape[2] != weight_scale_shape[2]: + raise ValueError( + f"Weights and scale must have the same number of experts and output dimensions. 
" + f"Got weights experts: {weights.shape[0]}, scale experts: {weight_scale_shape[0]}, " + f"weights out_dim: {weights.shape[2]}, scale out_dim: {weight_scale_shape[2]}" + ) + + k_dim = self.get_packed_tensor_logical_shape(weights)[1] + rounded_k_dim = (k_dim + 31) // 32 * 32 + block_size = rounded_k_dim // weight_scale_shape[1] + if block_size != 32: + raise ValueError(f"Block size must be 32. Got {block_size}") + + def compute_strides(self): + if self.weight_scale is not None: + # Check expected properties of the weights. + if self.swizzle_mx: + mxE, mxK, mxN = self.weight_scale.shape + + # Compute strides of the 5D swizzled tensor. + swizzled_shape = (mxE, mxN // 128, mxK // 4, 32, 4, 4) + s5 = 1 + s4 = swizzled_shape[5] * s5 # 4 * 1 = 4 + s3 = swizzled_shape[4] * s4 # 32 * 4 = 128 + s2 = swizzled_shape[3] * s3 # 4 * 128 = 512 + s1 = swizzled_shape[2] * s2 # (mxK//4) * 512 + s0 = swizzled_shape[1] * s1 # (mxN//128) * ((mxK//4)*512) + mx_scale_stride_e, mx_scale_stride_n, mx_scale_stride_k = s0, s1, s2 + else: + mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n = self.weight_scale.stride() + else: + mx_scale_stride_e = mx_scale_stride_k = mx_scale_stride_n = 0 + return mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n + + + def get_packed_tensor_logical_shape(self, tensor: torch.Tensor): + k_dim = tensor.shape[1] + if tensor.dtype == torch.uint8: + # Assume 2 fp4s packed into a byte + k_dim *= 2 + return tensor.shape[0], k_dim, tensor.shape[2] + +@dataclass(frozen=True) +class FlexCtx: + lhs_data: InFlexData = InFlexData() + rhs_data: InFlexData = InFlexData() + out_data: OutFlexData = OutFlexData() + +@dataclass +class PrecisionConfig: + max_num_imprecise_acc: int = None + allow_tf32: bool = True + flex_ctx: FlexCtx = FlexCtx() + acc_scale: int = 1.0 + flexpoint_saturate_inf: bool = False + report_quantization_err_fn: callable = None + + mx_ctx: MicroscalingCtx = MicroscalingCtx() + out_dtype: torch.dtype = None + enforce_bitwise_invariance: bool = False + + def __post_init__(self): + assert self.flex_ctx.rhs_data.scale is None or self.mx_ctx.weight_scale is None, "flex and mx_ctx cannot be used together" + +def mx_can_use_tma(mx_ctx: MicroscalingCtx): + mx_scale_stride_e, mx_scale_stride_n, mx_scale_stride_k = mx_ctx.compute_strides() + if mx_scale_stride_e * mx_ctx.weight_scale.element_size() % 16 != 0: + return False + + if mx_ctx.swizzle_mx: + # CHeck stride in bytes are multiples of 16. + return mx_scale_stride_n * mx_ctx.weight_scale.element_size() % 16 == 0 and mx_scale_stride_k * mx_ctx.weight_scale.element_size() % 16 == 0 + else: + # Check MX is either transposed or non-transposed, and with required stride. + return ( + (mx_scale_stride_n * mx_ctx.weight_scale.element_size() % 16 == 0 and mx_scale_stride_k == 1) or + (mx_scale_stride_k * mx_ctx.weight_scale.element_size() % 16 == 0 and mx_scale_stride_n == 1) + ) + +def can_use_persistent_tma(x, w, gather_indx, precision_config): + mx_ctx = precision_config.mx_ctx + return ( + # TMA requires CUDA 9.0, last dim contiguous, and multiple of 16-byte strides otherwise. + target_info.cuda_capability_geq(9, 0) and + (True if gather_indx is not None else + # Check strides of X. + x.stride(1) * x.element_size() % 16 == 0 and x.stride(2) == 1 + ) and ( + # Check W is either transposed or non-transposed, and with required stride. 
+ (w.stride(1) * w.element_size() % 16 == 0 and w.stride(2) == 1) or + (w.stride(2) * w.element_size() % 16 == 0 and w.stride(1) == 1) + ) and ( + mx_ctx.weight_scale is None or mx_can_use_tma(mx_ctx) + ) and ( + # MFXP4 tma requires 128 elements on the inner dim. + # MFXP4 is represented as packed uint8. + w.dtype != torch.uint8 or w.shape[-1] % 128 == 0 + ) + # compiler crash ? + and (x.dtype.itemsize <= 1 or w.dtype != torch.uint8) + ) + +def can_use_fused_scatter(scatter_indx): + return scatter_indx is not None + +# --------------------- +# Preprocessing +# --------------------- + +@dataclass(frozen=True) +class PreprocessingFeatures: + w_want_n_major: bool + w_want_k_major: bool + swap_xw: bool + + def __post_init__(self): + assert not (self.w_want_k_major and self.w_want_n_major), "Cannot have both K-major and N-major" + +def init_preprocessing_features(w, precision_config, opt_flags): + mx_ctx = precision_config.mx_ctx + swap_xw = False # Whether or not to swap X and W operands to the tl.dot + w_want_k_major = False + w_want_n_major = False + if not target_info.cuda_capability_geq(10, 0): + # Hopper transpose. Reduction dimension must be contiguous. + if w.stride(1) != 1 and w.dtype.itemsize == 1: + w_want_k_major = True + + if target_info.cuda_capability_geq(10, 0): + swap_xw = mx_ctx.weight_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent + if swap_xw: + w_want_k_major = True + # fp4 padded mode requires the contiguous dim size to be a multiple of 64 bytes. If it is K-major and does not + # meet the requirement, make the tensor N-major instead. + # But, don't do this if we're going to swap X and W in which case we would transpose W again. + if w.stride(1) == 1 and w.dtype == torch.uint8 and w.shape[1] % 64 != 0 and not swap_xw: + w_want_n_major = True + return PreprocessingFeatures(w_want_n_major, w_want_k_major, swap_xw) + + +def apply_preprocessing_features(x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features): + has_fused_scatter_scratchpad = opt_flags.fused_scatter and routing_data.n_expts_act > 1 + if has_fused_scatter_scratchpad: + M = scatter_indx.src_indx.shape[0] + writeback_idxs = torch.empty((M,), dtype=torch.int32, device=x.device) + writeback_size = writeback_idxs.shape[0] + finalize_scatter_idxs = torch.zeros((M // routing_data.n_expts_act + M + 1,), dtype=torch.int32, device=x.device) + BLOCK_M=256 + _compute_writeback_idx[(triton.cdiv(M, BLOCK_M),)]( + writeback_idxs, + finalize_scatter_idxs, + scatter_indx.dst_indx, + scatter_indx.src_indx, + M // routing_data.n_expts_act, + M, + BLOCK_M=BLOCK_M, + N_EXPTS_ACT=routing_data.n_expts_act, + ) + elif scatter_indx is not None and routing_data.n_expts_act == 1: + writeback_idxs = scatter_indx.dst_indx + writeback_size = scatter_indx.dst_indx.shape[0] + finalize_scatter_idxs = None + else: + writeback_idxs, writeback_size, finalize_scatter_idxs = None, None, None + # some transposition variants aren't supported + if preprocessing_features.w_want_n_major: + w = fast_contiguous(w) + elif preprocessing_features.w_want_k_major: + w = fast_contiguous(w.transpose(-1, -2)).transpose(-1, -2) + # preprocess routing information and ptr lookup table + M = x.shape[1] if gather_indx is None else gather_indx.src_indx.shape[0] + expt_data = compute_metadata(routing_data, M, opt_flags.block_m) + return x, w, preprocessing_features.swap_xw, writeback_idxs, writeback_size, finalize_scatter_idxs, expt_data + + +# --------------------- +# Postprocessing +# --------------------- + + 
+@dataclass(frozen=True) +class PostprocessingFeatures: + finalize: bool + +def init_postprocessing_features(routing_data, scatter_indx, opt_flags): + finalize = (scatter_indx is not None and routing_data.n_expts_act > 1) or opt_flags.split_k > 1 + return PostprocessingFeatures(finalize) + +def apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_offs, num_indx, precision_config, routing_data, + postprocess_features, memory, epilogue): + out = memory["output"] + flex_ctx = precision_config.flex_ctx + if postprocess_features.finalize: + has_fused_scatter_scratchpad = opt_flags.fused_scatter and routing_data.n_expts_act > 1 + if has_fused_scatter_scratchpad: + inp = memory["output"] + else: + inp = memory["scratchpad"]["matmul"] + if scatter_indx is not None: + assert inp.shape[1] == 1, "batched finalize scatter not supported" + n_final_rows = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act + scatter_src_indx = scatter_indx.src_indx + EXPT_PER_TOK = routing_data.n_expts_act + num_rows = None + else: + n_final_rows = inp.shape[1] * inp.shape[2] + scatter_src_indx = None + EXPT_PER_TOK = 1 + num_rows = num_indx or (None if expt_offs is None else expt_offs[-1]) + + if inp.dtype == torch.float32: + inp_flex = OutFlexData() + else: + inp_flex = precision_config.flex_ctx.out_data + + out_scatter = memory["output"] + out_scatter_flex = precision_config.flex_ctx.out_data + + N = inp.shape[3] + M = n_final_rows + warps_per_sm = 32 if target_info.is_hip() else 128 + + def compute_grid(BLOCK_N, num_warps): + num_pid = target_info.num_sms() * (warps_per_sm // num_warps) + if M < num_pid or target_info.is_hip(): + grid_n = triton.cdiv(N, BLOCK_N) + grid_m = min(M, max(1, triton.cdiv(num_pid, grid_n))) + else: + grid_m = min(M, num_pid) + grid_n = 1 + return (grid_m, grid_n) + + if inp.dtype.itemsize == 1: + candidates = [(1024, 1)] + else: + if target_info.is_hip(): + candidates = [(4096 // inp.dtype.itemsize, 2)] + else: + if inp.dtype.itemsize == 2: + candidates = [ + (4096 // inp.dtype.itemsize, 4), + (1024 // inp.dtype.itemsize, 1), + ] + else: + candidates = [ + (2048 // inp.dtype.itemsize, 4), + (1024 // inp.dtype.itemsize, 1), + ] + if precision_config.enforce_bitwise_invariance: + candidates = [candidates[0]] + + # sort by smallest grid_n so we share compute across a row + grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0] + STAGES = 1 if num_warps == 1 else min(triton.cdiv(triton.cdiv(N, BLOCK_N), grid[1]), 5) + + kernels = get_kernels(epilogue) + kernels._finalize_matmul[grid]( + flex_ctx.out_data.reinterpret(out_scatter), + *out_scatter_flex, + flex_ctx.out_data.reinterpret(inp), inp.stride(0), inp.stride(2), + inp_flex.expected_scale, + scatter_src_indx, finalize_scatter_idxs, + inp.shape[0], M, N, num_rows, + *epilogue.fn_arg_values_finalize, + EXPT_PER_TOK=EXPT_PER_TOK, + BLOCK_N=BLOCK_N, + STAGES=STAGES, + num_warps=num_warps, + flexpoint_saturate_inf=precision_config.flexpoint_saturate_inf, + HAS_FUSED_SCRATCHPAD=has_fused_scatter_scratchpad, + ) + out = out_scatter + # trim unnecessary part of output + if has_fused_scatter_scratchpad: + # Discard scratchpad part. + # This still gives a contiguous tensor, because shape[0] > 1 only when + # batch mode is enabled, in which case this is a no-op (there's no scratchpad). 
+ out = out[:, :, :n_final_rows, :] + return out + + +# --------------------- +# Allocation +# --------------------- + +@dataclass +class MatmulAllocation: + device: str + output: tuple[tuple[int], torch.dtype] + scratchpads: dict[str, tuple] + +def init_allocation(x, w, precision_config, routing_data, gather_indx, scatter_indx, opt_flags, + preprocessing_features, postprocessing_features): + # ---- output ------ + N = precision_config.mx_ctx.get_packed_tensor_logical_shape(w)[-1] + # by default - M is number of rows in the activations + M = x.shape[1] + # if the activations are gathered, then M is number of gather indices + if gather_indx is not None: + M = gather_indx.src_indx.shape[0] + # final output + if routing_data.n_expts_act == 1 or scatter_indx is None: + y_rows = M + elif opt_flags.fused_scatter: + # we need the scratchpad and the output to be contiguous in memory + Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows + y_rows = M + Mc + else: + Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows + y_rows = Mc + y_shape = (x.shape[0], y_rows, N) + out_dtype = precision_config.out_dtype or x.dtype + output = (y_shape, out_dtype) + # ---- scratchpad -----# + scratchpad = dict() + # if we need either standalone scatter or split-k, the matmul output will need post-processing + if postprocessing_features.finalize and (opt_flags.split_k > 1 or not opt_flags.fused_scatter): + dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype + scratchpad["matmul"] = ((opt_flags.split_k, x.shape[0], M, N), dtype) + return MatmulAllocation(x.device, output, scratchpad) + + +def apply_allocation(allocation: MatmulAllocation, output): + ret = dict() + if output is None: + output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1]) + else: + assert output.shape == allocation.output[0] + ret["output"] = output[None, :, :] + ret["scratchpad"] = { + k: torch.empty(v[0], device=allocation.device, dtype=v[1]) + for k, v in allocation.scratchpads.items() + } + return ret + +# ----------------------------------------------------------------------------- +# Triton Implementation +# ----------------------------------------------------------------------------- + +def matmul_ogs(x, w, bias, + routing_data: RoutingData | None = None, + gather_indx: GatherIndx | None = None, + scatter_indx: ScatterIndx | None = None, + precision_config: PrecisionConfig | None = None, + betas: torch.Tensor | None = None, + gammas: torch.Tensor | None = None, + out_alpha: float | None = None, + y: torch.Tensor | None = None, + epilogue: Epilogue | None = None, + ): + """ + Y[:, :] = 0. 
+ for e in num_experts: + Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :]) + """ + + is_input_batched = x.ndim == 3 + if is_input_batched: + assert gather_indx is None, "gather not supported in batched mode" + assert scatter_indx is None, "scatter not supported in batched mode" + assert routing_data is None, "routing not supported in batched mode" + assert w.ndim == 3 and w.shape[0] == x.shape[0] + if precision_config is None: + precision_config = PrecisionConfig() + if epilogue is None: + epilogue = Epilogue("dflt", None, tuple(), tuple(), tuple(), False) + if w.ndim == 2: + w = w.view(1, w.shape[-2], w.shape[-1]) + if x.ndim == 2: + x = x.view(1, x.shape[-2], x.shape[-1]) + assert w.ndim == 3 + assert x.ndim == 3 + # unpack scales + mx_ctx = precision_config.mx_ctx + # determine shapes + M = x.shape[1] if gather_indx is None else gather_indx.src_indx.shape[0] + if routing_data is None: + routing_data = RoutingData(None, None, w.shape[0], 1) + batch_size = w.shape[0] if routing_data.expt_hist is None else 1 + n_expts_tot, K, N = mx_ctx.get_packed_tensor_logical_shape(w) + mx_ctx.check_inputs(w) + mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n = mx_ctx.compute_strides() + # compute optimization flags + out_dtype = precision_config.out_dtype or x.dtype + opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config, + M, N, K, routing_data, + can_use_persistent_tma(x, w, gather_indx, precision_config), + can_use_fused_scatter(scatter_indx), + epilogue.is_expensive, + ) + # compute grid size + if not is_input_batched: + grid_m = routing_data.n_blocks(M, opt_flags.block_m) + else: + grid_m = triton.cdiv(M, opt_flags.block_m) + grid_n = triton.cdiv(N, opt_flags.block_n) + assert n_expts_tot == routing_data.n_expts_tot + assert grid_m > 0 + assert x.dtype == w.dtype or mx_ctx.weight_scale is not None + # determine necessary pre/post processing + preprocessing_features = init_preprocessing_features(w, precision_config, opt_flags) + postprocessing_features = init_postprocessing_features(routing_data, scatter_indx, opt_flags) + # allocate output/scratchpad memory + allocation = init_allocation(x, w, precision_config, routing_data, gather_indx, scatter_indx, opt_flags, + preprocessing_features, postprocessing_features) + memory = apply_allocation(allocation, y) + # TMA descriptors require a global memory allocation + if opt_flags.is_persistent: + triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device)) + # Intermediate tensors and postprocess kernels for each situation + out0, out0_flex = memory["output"], precision_config.flex_ctx.out_data + if postprocessing_features.finalize: + if opt_flags.fused_scatter: + out0 = memory["output"] + else: + out0 = memory["scratchpad"]["matmul"] + out0_flex = OutFlexData() if out0.dtype == torch.float32 else precision_config.flex_ctx.out_data + # pre-processing + x, w, swap_xw, writeback_idxs, writeback_size, finalize_scatter_idxs, expt_data = apply_preprocessing_features( + x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features + ) + if expt_data.buffer is not None: + assert expt_data.buffer.shape[0] == 3*n_expts_tot + 2 + grid_m, \ + f"invalid expt_data, {expt_data.buffer.shape}, {n_expts_tot=}, {grid_m=}" + # matrix multiplication + n_cta = batch_size * grid_m * grid_n * opt_flags.split_k + n_cta = min(target_info.num_sms(), n_cta) if opt_flags.is_persistent else n_cta + flex = precision_config.flex_ctx + bias_stride = None if bias is None else bias.stride(0) + num_indx = None if scatter_indx is 
None else scatter_indx.src_indx.shape[0] + kernels = get_kernels(epilogue) + (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)]( + flex.out_data.reinterpret(memory["output"]), + flex.out_data.reinterpret(out0), *out0.stride(), + *out0_flex, + flex.lhs_data.reinterpret(x), x.stride(0), x.stride(1), x.stride(2), + flex.lhs_data.scale, + flex.rhs_data.reinterpret(w), w.stride(0), w.stride(1), w.stride(2), w.stride(2) != 1, + flex.rhs_data.scale, + mx_ctx.weight_scale, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_scale_stride_n != 1, + bias, bias_stride, + x.shape[1], + x.shape[1] if routing_data.expt_hist is None else None, + N, K, + betas, gammas, + None if gather_indx is None else gather_indx.src_indx, + None if scatter_indx is None else scatter_indx.src_indx, + num_indx, + writeback_idxs, writeback_size, + expt_data.hist, expt_data.offs, expt_data.offs_sum, expt_data.blocks, + batch_size, grid_m, grid_n, + out_alpha, + *epilogue.fn_arg_values_matmul, + routing_data.n_expts_tot, routing_data.n_expts_act, + precision_config.max_num_imprecise_acc, + precision_config.allow_tf32, + precision_config.flexpoint_saturate_inf, + flex.rhs_data.is_per_batch, + opt_flags.block_m, + opt_flags.block_n, + opt_flags.block_k, + opt_flags.group_m, + XCD_SWIZZLE=opt_flags.xcd_swizzle, + SWIZZLE_MX=mx_ctx.swizzle_mx, + EPILOGUE_SUBTILE=opt_flags.epilogue_subtile, + SPLIT_K=opt_flags.split_k, + EVEN_K=K % opt_flags.block_k == 0, + W_CACHE_MODIFIER=opt_flags.w_cache_modifier, + TOKENS_PER_EXPT_FOR_ANNOTATION=routing_data.expected_tokens_per_expt, + num_warps=opt_flags.num_warps, + num_stages=opt_flags.num_stages, + arch=opt_flags.arch, + UPCAST_INDICES=should_upcast_indices(x, w, out0), + DISABLE_Y_TMA=out0.stride(-2) * out0.dtype.itemsize % 16 != 0, + SWAP_XW=swap_xw, + NUM_SMS = n_cta, + **opt_flags.target_kernel_kwargs) + # post-processing + out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_data.offs, + num_indx, precision_config, routing_data, + postprocessing_features, memory, epilogue) + + # remove split-k + out = out.squeeze(0) + if not is_input_batched: + out = out.view(out.shape[-2], out.shape[-1]) + return out + + +# ----------------------------------------------------------------------------- +# Reference Implementation +# ----------------------------------------------------------------------------- + +def matmul_ogs_torch(x, w, bias, + routing_data: RoutingData = None, + gather_indx: GatherIndx = None, + scatter_indx: ScatterIndx = None, + precision_config: PrecisionConfig = None, + betas = None, + gammas = None, + round_x = None, round_y = None, + ): + is_input_batched = x.ndim == 3 + assert x.dtype.itemsize > 1 + assert w.dtype.itemsize > 1 + if is_input_batched: + assert gather_indx is None, "gather not supported in batched mode" + assert scatter_indx is None, "scatter not supported in batched mode" + assert routing_data is None, "routing not supported in batched mode" + assert w.ndim == 3 and w.shape[0] == x.shape[0] + if round_x is None: + round_x = lambda x: x + if round_y is None: + round_y = lambda x: x + if w.ndim == 2: + w = w.view(1, w.shape[0], w.shape[1]) + if x.ndim == 2: + x = x.view(1, x.shape[0], x.shape[1]) + if routing_data is None: + routing_data = RoutingData(None, None, w.shape[0], 1) + n_expts_act = routing_data.n_expts_act + # memory offsets + if routing_data.n_expts_tot > 1 and not is_input_batched: + sizes = routing_data.expt_hist + off = torch.zeros(sizes.shape[0] + 1, 
dtype=torch.int32) + off[1:] = torch.cumsum(sizes, 0) + offs = list(itertools.pairwise(off)) + else: + offs = [[0, x.shape[1]] for _ in range(w.shape[0])] + # compute + n_rows = x.shape[1] if gather_indx is None else gather_indx.dst_indx.shape[0] + y = torch.zeros((x.shape[0], n_rows, w.shape[-1]), device=x.device, dtype=x.dtype) + for i, (lo, hi) in enumerate(offs): + if gather_indx is None: + idx = torch.arange(lo, hi, device=x.device) + else: + idx = gather_indx.src_indx[lo:hi] // n_expts_act + batch = i if is_input_batched else 0 + out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(), + w[i, :, :].float()) + if bias is not None: + out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None] + if gammas is not None: + out *= gammas[lo:hi, None] + y[batch, lo:hi, :] = round_y(out) + if not is_input_batched: + y = y.view(y.shape[1], y.shape[2]) + if scatter_indx is None: + return y + # accumulate output from all experts + n_rows = y.shape[0] // n_expts_act + out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device) + for i, (lo, hi) in enumerate(offs): + dst_idx = scatter_indx.dst_indx[lo:hi] // n_expts_act + msk = dst_idx != -1 + out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float() + return out diff --git a/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_common.py b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_common.py new file mode 100644 index 0000000..28f920b --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_common.py @@ -0,0 +1,96 @@ +import triton +import triton.language as tl + +# ----------------------------------------------------------------------------- +# Utilities +# ----------------------------------------------------------------------------- + + +@triton.jit +def xcd_swizzle(pid, domain_size, XCD_SWIZZLE: tl.constexpr): + """ + Swizzle the program id based on integer XCD_SWIZZLE. + This is useful for reording how blocks are ordered. A scheduler may, for example, + assign sequential blocks 0, 1, 2, 3, ..., 8, 9, 10.. to its 8 hardware units 0, 1, 2, 3, ..., 0, 1, 2. + This pattern may not be ideal for memory access, and it may be better to swizzle so the assignment + becomes 0, 0, 0, 0, ..., 1, 1, 1, ... In the swizzled arrangement, sequential blocks are assigned to + the same hardware unit. 
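+
+    For example (hypothetical sizes): with domain_size=8 and XCD_SWIZZLE=4,
+    pids 0..7 are remapped to 0, 2, 4, 6, 1, 3, 5, 7, so consecutive swizzled
+    ids (e.g. 0 and 1) come from pids 0 and 4, which a round-robin scheduler
+    with 4 units would have placed on the same hardware unit.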
+ """ + # Number of pids per group in the new arrangement + pids_per_group = domain_size // XCD_SWIZZLE + extra_pid_groups = domain_size % XCD_SWIZZLE + + # Compute current current and local pid within the group + group = pid % XCD_SWIZZLE + local_pid = pid // XCD_SWIZZLE + + # Calculate new pid based on the new grouping + new_pid = group * pids_per_group + min(group, extra_pid_groups) + local_pid + return new_pid + + +@triton.jit +def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr): + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + return pid_m, pid_n + + +def make_matmul_repr(base_name, order): + + def matmul_repr(specialization): + signature = specialization.signature + constants = specialization.constants + reorder = lambda L: [L[i] for i in order] + layout = lambda stride: "N" if stride in constants else "T" + convert_dtype = lambda dtype: "mxfp4" if "u8" in dtype else dtype + dtypes = "x".join([convert_dtype(f"{signature[i][1:]}") for i in reorder(["Y", "X", "W"])]) + layouts = "".join([f"{layout(i)}" for i in reorder(["stride_y_n", "stride_x_k", "stride_w_n"])]) + blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N", "BLOCK_K", "SPLIT_K"]]) + # mode = [] + # if "GatherIndx" not in constants: + # mode += ['g'] + # if "ScatterSrcIndx" not in constants: + # mode += ['s'] + # suffix = "" if not mode else "_o" + (''.join(mode)) + # if base_name.startswith("_p"): + # suffix += "_ptma" + return f"{base_name}_{layouts}_{dtypes}_{blocks}" + + return matmul_repr + + +def matmul_launch_metadata(grid, kernel, args): + ret = dict() + M, N, K = args["M"], args["N"], args["K"] + Y, X, W = args["Y"], args["X"], args["W"] + hist = args["ExptHist"] + if hist is not None: + n_tokens = float(hist.sum()) + n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum() + + # If annotation is given, use that to generate name for profiling. + tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION") + n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else int(hist.float().mean()) + else: + n_tokens = None + n_w_bytes = W.numel() * W.element_size() + repr = lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(hist)}({s}) = {n_rows}" + nbits = X.dtype.itemsize * 8 + batch_repr = "" + if "batch_size" in args and args["batch_size"] > 1: + batch_repr = repr("B", args["batch_size"]) + ", " + ret["name"] = f"{kernel.name} [{batch_repr}{repr('M', M)}, {repr('N', N)}, {repr('K', K)}]" + fM = M if M is not None else n_tokens + fK = K if K is not None else n_tokens + ret[f"flops{nbits}"] = 2.0 * fM * N * fK + gindx = args.get("GatherIndx", None) + sindx = args.get("WriteBackIndx", None) + sskipped = 0. if sindx is None else (sindx == -1).sum() / sindx.shape[0] + gskipped = 0. 
if gindx is None else (gindx == -1).sum() / gindx.shape[0] + ret["bytes"] = int((1 - sskipped) * Y.numel() * Y.element_size() + (1 - gskipped) * X.numel() * X.element_size() + + n_w_bytes) + return ret diff --git a/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_finalize_matmul.py b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_finalize_matmul.py new file mode 100644 index 0000000..5a14bbc --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_finalize_matmul.py @@ -0,0 +1,325 @@ +import triton +import triton.language as tl +from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale, update_scale +from triton_kernels.target_info import cuda_capability_geq as _cuda_capability_geq +from triton_kernels.target_info import is_hip as _is_hip + + +@tl.constexpr_function +def is_hip(): + return _is_hip() + + +@tl.constexpr_function +def cuda_capability_geq(x, y): + return _cuda_capability_geq(x, y) + + +@tl.constexpr_function +def log2(n): + return len(bin(n)) - 3 + + +@tl.constexpr_function +def _permute_to_end_order(n: int, axis: int): + """ + Returns the order of the axes of a tensor to permute `axis` to the end. + """ + order = tuple(range(n)) + return order[:axis] + order[(axis + 1):] + (axis, ) + + +@triton.jit +def permute_to_end(x, axis: tl.constexpr): + """ + Permutes `x` so that `axis` is the last axis. + """ + N: tl.constexpr = len(x.shape) + return tl.permute(x, _permute_to_end_order(N, axis).value) + + +@triton.jit +def split_n(x, N: tl.constexpr): + """ + Given `x`, a tensor of shape AxB...x2x2...x2, split it N times. + Return a tuple of the results. + """ + xs = (x, ) + for i in tl.static_range(N): + next = tl.split(xs[0]) + for j in tl.static_range(2**i - 1): + next = next + tl.split(xs[j + 1]) + xs = next + return xs + + +@triton.jit +def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr = None, NUM_THREADS: tl.constexpr = None): + N: tl.constexpr = tl.extra.cuda.num_threads() if NUM_THREADS is None else NUM_THREADS + BS: tl.constexpr = x.numel if BLOCK_SIZE is None else BLOCK_SIZE + tl.static_assert(BS % N == 0, "BLOCK_SIZE must be divisible by NUM_THREADS") + return tl.max(tl.reshape(tl.abs(x), [N, BS // N], can_reorder=True), axis=1) + + +def _finalize_matmul_launch_metadata(grid, kernel, args): + ret = dict() + Out, A, ScatterSrcIndx, FinalizeScatterIdxs, K, M, N, EXPT_PER_TOK, NumRows = [ + args[name] + for name in ["Out", "A", "ScatterSrcIndx", "FinalizeScatterIdxs", "K", "M", "N", "EXPT_PER_TOK", "NumRows"] + ] + ret["name"] = f"{kernel.name} [M={M}x{EXPT_PER_TOK} {N=} {K=}]" + + if FinalizeScatterIdxs is not None: + M = FinalizeScatterIdxs[-1].item() + + if ScatterSrcIndx is not None: + is_active = (ScatterSrcIndx != -1).view((-1, EXPT_PER_TOK)) + n_active = is_active.sum(dim=1) + need_accum = n_active >= (1 if K > 1 else 2) + is_active &= need_accum[:, None] + active_input_rows = is_active.sum() + active_output_rows = need_accum.sum() + if EXPT_PER_TOK > 1: + # Masked rows are set to zero. 
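+            # (Those rows are still written out as zeros by the kernel, so they
+            # contribute to the store-bytes estimate computed below.)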
+ active_output_rows += (n_active == 0).sum() + else: + if NumRows is not None: + if isinstance(NumRows, int): + active_input_rows = NumRows + else: + active_input_rows = NumRows.item() + else: + active_input_rows = M + active_output_rows = M + + ret["bytes"] = (active_input_rows * K * A.shape[-1] * A.element_size() + + active_output_rows * Out.shape[-1] * Out.element_size()) + if FinalizeScatterIdxs is not None: + ret["bytes"] += FinalizeScatterIdxs.numel() * FinalizeScatterIdxs.element_size() + elif ScatterSrcIndx is not None and EXPT_PER_TOK > 1: + ret["bytes"] += ScatterSrcIndx.numel() * ScatterSrcIndx.element_size() + nbits = Out.dtype.itemsize * 8 + ret[f"flops{nbits}"] = active_input_rows * K * A.shape[-1] + return ret + + +@tl.constexpr_function +def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None): + """ + Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision + adds. If `absmax_reg_name` is provided, the absolute maximum value seen so far is tracked inside + that register. + + Generates code something like: + + add.f32.f16 $0, $2, $1; + add.f32.f16 $0, $3, $0; + add.f32.f16 $0, $4, $0; + add.f32.f16 $0, $5, $0; + + .reg .f32 b; + abs.f32 b, $0; + max.f32 my_abs_max, my_abs_max, b; + """ + # Add the first f16 value to the input $1, store into the output $0. + ptx = f"\nadd.f32.{src_type} $0, $2, $1;" + # Accumulate the rest of the inputs into the output $0. + for i in range(1, n_inputs): + ptx += f"\nadd.f32.{src_type} $0, ${2 + i}, $0;" + if absmax_reg_name is not None: + # Update `absmax_reg_name` with the absolute maximum value seen so far. + ptx += f""" + .reg .f32 b; + abs.f32 b, $0; + max.f32 {absmax_reg_name}, {absmax_reg_name}, b; + """ + # Return the PTX snippet, brace-enclosed so we don't pollute the global namespace. + return f"{{{ptx}}}" + + +@triton.jit +def _mixed_precision_accumulate_and_track_absmax(acc, x, axis: tl.constexpr, absmax_reg_name: tl.constexpr = None): + """Given an fp8/bf16/fp16 tensor, accumulate into `acc` along `axis`. + Values are first converted to bf16/fp16, packed into 32-bit registers, and then accumulated using + mixed-precision adds. + + If `absmax_reg_name` is provided, the absolute maximum value seen so far is tracked inside that + register. + """ + REDUCTION_SIZE: tl.constexpr = x.shape[axis] + tl.static_assert(2**log2(REDUCTION_SIZE) == REDUCTION_SIZE, + f"Reduction size must be a power of 2, was {REDUCTION_SIZE}") + # move `axis` to the last axis and reshape for iterative splitting. + x = permute_to_end(x, axis) + x = tl.reshape(x, x.shape[:-1] + (2, ) * log2(REDUCTION_SIZE)) + # Split into a tuple of AxB tensors. + xs = split_n(x, log2(REDUCTION_SIZE)) + if (tl.constexpr(x.dtype == tl.float8e4nv) or tl.constexpr(x.dtype == tl.float8e5)): + # Convert fp8 to fp16. 
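+        # The PTX helper above only emits add.f32.f16 / add.f32.bf16, so fp8
+        # inputs are upcast to fp16 before being fed to the mixed-precision adds.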
+ fp16_xs = () + for i in tl.static_range(len(xs)): + fp16_xs += (xs[i].to(tl.float16), ) + xs = fp16_xs + src_type: tl.constexpr = "f16" + elif x.dtype == tl.float16: + src_type: tl.constexpr = "f16" + elif x.dtype == tl.bfloat16: + src_type: tl.constexpr = "bf16" + else: + tl.static_assert(False, f"Unsupported dtype: {x.dtype}") + return tl.inline_asm_elementwise( + _accumulate_f16_into_f32_and_track_absmax_ptx(REDUCTION_SIZE, src_type, absmax_reg_name), + "=r,r" + (",h" * len(xs)), + (acc, ) + xs, + dtype=tl.float32, + is_pure=True, + pack=1, + ) + + +def _finalize_matmul_repr(specialization): + signature = specialization.signature + suffix = "" if "ScatterSrcIndx" in specialization.constants else "_scatter" + return f"_finalize_matmul{suffix}_{signature['A'][1:]}" + + +@triton.jit(repr=_finalize_matmul_repr, launch_metadata=_finalize_matmul_launch_metadata) +def _finalize_matmul( + Out, + OutExpectedScale, + OutActualScale, + OutChecksumScale, + A, + stride_a_k, + stride_a_m, + AScale, + ScatterSrcIndx, + FinalizeScatterIdxs, + K: tl.constexpr, + M, + N, + NumRows, + EPILOGUE_FN: tl.constexpr, + epilogue_fn_args, + EXPT_PER_TOK: tl.constexpr, + flexpoint_saturate_inf: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGES: tl.constexpr, + HAS_FUSED_SCRATCHPAD: tl.constexpr, +): + if HAS_FUSED_SCRATCHPAD: + # Bump A to the scratchpad region. + A += tl.cast(M, tl.int64) * stride_a_m + + USE_FUSED_MIXED_PREC_ACC: tl.constexpr = (cuda_capability_geq(10, 0) + and tl.constexpr(A.dtype.element_ty != tl.float32)) + USE_FUSED_ABSMAX: tl.constexpr = USE_FUSED_MIXED_PREC_ACC and OutActualScale is not None + + THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads() + local_max = tl.full([THREADS_PER_BLOCK], 0.0, tl.float32) + if USE_FUSED_ABSMAX: + local_max = tl.inline_asm_elementwise( + """ + .reg .f32 my_abs_max; + mov.b32 my_abs_max, 0; + mov.b32 $0, 0; + """, "=r,r", [local_max], dtype=tl.float32, is_pure=False, pack=1) + + out_scale = load_scale(OutExpectedScale) + a_scale = load_scale(AScale) + + if FinalizeScatterIdxs is not None: + MBound = tl.load(FinalizeScatterIdxs + M + M * EXPT_PER_TOK) + if tl.program_id(0) >= MBound: + return + else: + MBound = M + + if NumRows is not None: + NumRows = NumRows # remove constexpr + if NumRows.dtype.is_ptr(): + NumRows = tl.load(NumRows) + + if FinalizeScatterIdxs is not None or (ScatterSrcIndx is not None and EXPT_PER_TOK > 1): + n_active_experts = 0 + else: + n_active_experts: tl.constexpr = EXPT_PER_TOK + + for pid_m in tl.range(tl.program_id(0), MBound, tl.num_programs(0)): + src_offs = pid_m * EXPT_PER_TOK + tl.arange(0, EXPT_PER_TOK) + if FinalizeScatterIdxs is not None: + row = tl.load(FinalizeScatterIdxs + pid_m) + src_idxs = tl.load(FinalizeScatterIdxs + M + src_offs) + n_active_experts = tl.sum((src_idxs != -1).to(tl.int32)) + elif ScatterSrcIndx is not None and EXPT_PER_TOK > 1: + row = pid_m + src_idxs = tl.load(ScatterSrcIndx + src_offs) + n_active_experts = tl.sum((src_idxs != -1).to(tl.int32)) + else: + row = pid_m + src_idxs = src_offs + if NumRows is not None: + src_idxs = tl.where(src_idxs < NumRows, src_idxs, -1) + + if n_active_experts == 0: + for off_n in tl.range(tl.program_id(1) * BLOCK_N, N, tl.num_programs(1) * BLOCK_N): + offs_n = off_n + tl.arange(0, BLOCK_N) + n_mask = offs_n < N + tl.store(Out + row * N + offs_n, tl.zeros([BLOCK_N], dtype=Out.dtype.element_ty), mask=n_mask) + else: + for off_n in tl.range(tl.program_id(1) * BLOCK_N, N, tl.num_programs(1) * BLOCK_N, num_stages=STAGES): + offs_n = off_n + tl.arange(0, 
BLOCK_N) + n_mask = offs_n < N + + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + if is_hip(): + if EXPT_PER_TOK > 1: + src_idxs_tup = split_n(tl.reshape(src_idxs, (2, ) * log2(EXPT_PER_TOK)), log2(EXPT_PER_TOK)) + else: + # Convert 1D tensor to 1D tuple. + src_idxs_tup = tl.split(tl.reshape(tl.join(src_idxs, src_idxs), (2, )))[:1] + for i in tl.static_range(0, EXPT_PER_TOK, 1): + src_idx = src_idxs_tup[i] + if src_idx != -1: + As = A + src_idx.to(tl.int64) * stride_a_m + offs_n + for ki in tl.static_range(K): + acc += tl.load(As, mask=n_mask, other=0.0) + As += stride_a_k + else: + As = A + src_idxs.to(tl.int64)[:, None] * stride_a_m + offs_n[None, :] + for ki in tl.static_range(K): + a = tl.load(As, mask=(src_idxs != -1)[:, None] & n_mask[None, :], other=0.0) + As += stride_a_k + if USE_FUSED_MIXED_PREC_ACC: + acc = _mixed_precision_accumulate_and_track_absmax( + acc, a, axis=0, + absmax_reg_name="my_abs_max" if USE_FUSED_ABSMAX and ki == K - 1 else None) + else: + acc += tl.sum(a, dtype=tl.float32, axis=0) + acc = acc * a_scale + if not USE_FUSED_ABSMAX and OutActualScale is not None: + local_max = tl.maximum(local_max, thread_local_absmax(acc)) + acc = float_to_flex(acc, out_scale if OutExpectedScale is not None else None, None, OutChecksumScale, + None, Out, flexpoint_saturate_inf) + if EPILOGUE_FN is not None: + acc = EPILOGUE_FN(acc, *epilogue_fn_args, target_dtype=Out.dtype.element_ty, + pid=row * tl.num_programs(1) + tl.program_id(1)) + tl.store(Out + row * N + offs_n, acc, mask=n_mask) + + persisent_m = tl.num_programs(0) < MBound + if not persisent_m and n_active_experts == 0: + # Skip updating the scale if there were no active experts and this is a non-persistent launch. + # The loop ran only once, and inside it we only stored zeros. + return + + if USE_FUSED_ABSMAX: + local_max = tl.inline_asm_elementwise( + "mov.b32 $0, my_abs_max;", + "=r,r", + [local_max], + dtype=tl.float32, + is_pure=True, + pack=1, + ) + local_max *= a_scale + update_scale(local_max, OutActualScale, Out) diff --git a/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_matmul_ogs.py new file mode 100644 index 0000000..d37baf1 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_matmul_ogs.py @@ -0,0 +1,347 @@ +import triton +import triton.language as tl +from triton_kernels.numerics_details.mxfp import _unswizzle_mx_block, get_scaled_dot_format_string +from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale +from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle + +# fmt: off + +@triton.jit +def _zero_masked_rows( + pid_m, pid_n, + Y, stride_y_m, stride_y_n, + N, + ScatterSrcIndx, num_idxs, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + offs_m = BLOCK_M * pid_m.to(tl.int64) + tl.arange(0, BLOCK_M) + offs_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N) + src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0) + YPtrs = Y + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n + mask_n = offs_n < N + mask = (src_idx == -1)[:, None] & mask_n[None, :] + tl.store(YPtrs, tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32), mask=mask) + + +_matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2]) +@triton.jit(repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata) +def _matmul_ogs( + Y, Out, stride_y_k, stride_y_z, stride_y_m, stride_y_n, + YExpectedScale, YActualScale, YChecksumScale, + X, stride_x_z, stride_x_m, stride_x_k, + 
XScale, + W, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr, + WScale, + MxScale, stride_mx_e, stride_mx_k, stride_mx_n, MX_TRANSPOSE: tl.constexpr, + B, stride_b_e, # Bias + NRows, M, N, K, # shapes + # expt data + Betas, Gammas, + GatherIndx, + ScatterSrcIndx, num_idxs, + WriteBackIndx, writeback_size, + ExptHist, ExptOffs, ExptOffsSum, ExptData, + # true grid size + batch_size, grid_m, grid_n, + # Out scale + out_alpha, + # epilogue transform + EPILOGUE_FN: tl.constexpr, epilogue_fn_args, + # MoE config + N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr, + # precision config + MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr, + FLEXPOINT_SATURATE_INF: tl.constexpr, + PER_BATCH_SCALE: tl.constexpr, + # optimization config + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr, SWIZZLE_MX: tl.constexpr, + EPILOGUE_SUBTILE: tl.constexpr, + EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, + W_CACHE_MODIFIER: tl.constexpr, + NUM_SMS: tl.constexpr, + TOKENS_PER_EXPT_FOR_ANNOTATION=None, + UPCAST_INDICES: tl.constexpr = False, + DISABLE_Y_TMA: tl.constexpr = True, + SWAP_XW: tl.constexpr = False): + + Y = Out # Y is passed for the purposes of annotation; replace it with Out + is_microscaled_format: tl.constexpr = MxScale is not None + MX_PACK_DIVISOR: tl.constexpr = 32 + if is_microscaled_format: + w_type: tl.constexpr = W.dtype.element_ty + tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5), + "mx_weight_ptr must be uint8") + tl.static_assert(MxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8") + tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") + + pid = tl.program_id(0) + if ExptOffsSum is not None and XCD_SWIZZLE > 1: + # Determine how much padding there is on the expert data. This allows us to + # know the true grid size and avoid processing padding tiles. 
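+        # grid_m is a worst-case bound on the number of row-tiles launched, while
+        # ExptOffsSum reflects how many are actually needed, so the difference
+        # counts the padding tiles that can be skipped.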
+ padding_m = grid_m - tl.load(ExptOffsSum) + else: + padding_m: tl.constexpr = 0 + + HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None + index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32 + + total_actual_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K + if padding_m > 0 and pid >= total_actual_tiles: + tl.device_assert(batch_size == 0) + pid_mn = pid - total_actual_tiles + if pid_mn < padding_m * grid_n: + pid_m, pid_n = swizzle2d(pid_mn, padding_m, grid_n, GROUP_M) + + # set masked out rows to 0 + if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1: + _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, N, ScatterSrcIndx, num_idxs, BLOCK_M, BLOCK_N) + return + + # swizzle program ids + pid_emnk = pid + if XCD_SWIZZLE != 1: + pid_emnk = xcd_swizzle(pid_emnk, total_actual_tiles, XCD_SWIZZLE) + pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K) + pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K) + pid_k = pid_mnk % SPLIT_K + pid_mn = pid_mnk // SPLIT_K + pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M) + # For split-k, advance to the output k slice + if SPLIT_K > 1: + Y += pid_k.to( index_type) * stride_y_k + # set masked out rows to 0 + if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1: + _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, N, ScatterSrcIndx, num_idxs, BLOCK_M, BLOCK_N) + # unpack expert data + if ExptData is None: + tl.static_assert(M is not None) + expt_id, start_z, start_m, block_id = pid_e, pid_e, 0, pid_m + else: + tl.static_assert(M is None) + expt_data = tl.load(ExptData + pid_m) + if expt_data == -1: + return + expt_id = expt_data & 0x0000FFFF + block_id = expt_data >> 16 + M = tl.load(ExptHist + expt_id) + start_m = tl.load(ExptOffs + expt_id) + start_z = 0 + expt_id, block_id = expt_id.to(index_type), block_id.to(index_type) + start_m, start_z = start_m.to(index_type), start_z.to(index_type) + pid_n, pid_k = pid_n.to(index_type), pid_k.to(index_type) + # A pointers + offs_x_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M) + offs_x_m = tl.max_contiguous(tl.multiple_of(offs_x_m % M, BLOCK_M), BLOCK_M) + X += start_z * stride_x_z + if GatherIndx is None: + X += start_m * stride_x_m + else: + GatherIndx += start_m + # no needs to bounds-check here because `offs_x_m` wraps around M dim + offs_x_m = tl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT + offs_k = BLOCK_K * pid_k + tl.arange(0, BLOCK_K) + XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k + # B pointers + offs_w_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % N, BLOCK_N), BLOCK_N) + + # TODO: refactor if/else when triton front end improves + if is_microscaled_format: + # We have pack 2 fp4 values in a byte + W_PACK_DIVISOR: tl.constexpr = 2 if W.dtype.element_ty == tl.uint8 else 1 + PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_PACK_DIVISOR + MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR + + MxScale += expt_id * stride_mx_e + + if SWIZZLE_MX: + tl.static_assert(BLOCK_N % 128 == 0) + tl.static_assert(MX_SCALE_BLOCK_K % 4 == 0) + PACKED_MX_BLOCK: tl.constexpr = (MX_SCALE_BLOCK_K // 4) * 32 * 4 * 4 + offs_inner = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK) + offs_n_scale = (pid_n * (BLOCK_N // 128) + tl.arange(0, BLOCK_N // 128)) % N + offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, BLOCK_N // 128), BLOCK_N // 128) + + MxScalePtrs = MxScale + offs_n_scale.to(index_type)[:, None] * stride_mx_n 
+ offs_inner[None, :] + else: + offs_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K) + offs_n_scale = offs_w_n + # K dimension must be the last dimension for the scales + MxScalePtrs = MxScale + offs_k_scale.to(index_type)[None, :] * stride_mx_k + offs_n_scale.to(index_type)[:, None] * stride_mx_n + else: + MxScalePtrs = None + offs_k_scale = None + W_PACK_DIVISOR: tl.constexpr = 1 + MX_SCALE_BLOCK_K: tl.constexpr = 1 + PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K + + offs_w_k = PACKED_BLOCK_K_W * pid_k + tl.arange(0, PACKED_BLOCK_K_W) + W += expt_id * stride_w_e + WPtrs = W + (offs_w_k.to(index_type)[:, None] * stride_w_k + offs_w_n.to(index_type)[None, :] * stride_w_n) + # compute output + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(K, BLOCK_K * pid_k, -(BLOCK_K * SPLIT_K)): + if EVEN_K: + mask_k = tl.full([BLOCK_K], True, dtype=tl.int1) + mask_k_w = tl.full([PACKED_BLOCK_K_W], True, dtype=tl.int1) + mask_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1) + else: + mask_k = offs_k < k + mask_k_w = offs_w_k < tl.cdiv(k, W_PACK_DIVISOR) + if is_microscaled_format and not SWIZZLE_MX: + mask_k_scale = offs_k_scale < tl.cdiv(k, MX_PACK_DIVISOR) + + x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0) + w = tl.load(WPtrs, mask=mask_k_w[:, None], other=0.0, cache_modifier=W_CACHE_MODIFIER) + if is_microscaled_format: + x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype) + mx_format: tl.constexpr = get_scaled_dot_format_string(w.dtype) + if x_format == "fp16" or x_format == "bf16": + x_scales: tl.constexpr = None + else: + # Scale of 1 in E8M0 format + x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8) + if SWIZZLE_MX: + w_scales = _unswizzle_mx_block(tl.load(MxScalePtrs)) + else: + w_scales = tl.load(MxScalePtrs, mask=mask_k_scale[None, :], other=0.0) + acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True) + if SWIZZLE_MX: + MxScalePtrs += (MX_SCALE_BLOCK_K // 4 * SPLIT_K) * stride_mx_k + else: + MxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_mx_k + else: + acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32) + XPtrs += (BLOCK_K * SPLIT_K) * stride_x_k + WPtrs += (PACKED_BLOCK_K_W * SPLIT_K) * stride_w_k + # bias + scale + offs_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M) + offs_y_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_y_n < N + if B is not None: + BPtrs = B + expt_id * stride_b_e + offs_y_n + if pid_k == 0: + bias = tl.load(BPtrs, mask=mask_n, other=0) + else: + bias = tl.full([BLOCK_N], 0, dtype=tl.float32) + else: + bias = tl.full([BLOCK_N], 0, dtype=tl.float32) + if Betas is not None: + betas = tl.load(Betas + start_m + offs_m, mask=mask_m, other=0.0) + else: + betas = tl.full([BLOCK_M], 1, dtype=tl.float32) + if Gammas is not None: + gammas = tl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0) + else: + gammas = tl.full([BLOCK_M], 1, dtype=tl.float32) + # flexpoint + x_scale = load_scale(XScale) + if PER_BATCH_SCALE: + w_scale = load_scale(WScale + expt_id) + else: + w_scale = load_scale(WScale) + acc *= x_scale * w_scale + acc = acc + bias[None, :] * betas[:, None] + acc *= gammas[:, None] + if out_alpha is not None: + acc *= out_alpha + # write-back + Y += start_z.to(index_type) * stride_y_z + if WriteBackIndx is not None: + WriteBackIndx += start_m + dst_idx = tl.load(WriteBackIndx + offs_m, mask=start_m + offs_m < writeback_size, other=-1) + mask_m = mask_m & 
(dst_idx != -1) + offs_y_m = dst_idx + else: + Y += start_m * stride_y_m + offs_y_m = offs_m + + YPtrs = Y + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n.to(index_type)[None, :] * stride_y_n + mask = mask_m[:, None] & mask_n[None, :] + acc = float_to_flex(acc, YExpectedScale, YActualScale, YChecksumScale, mask, Y, FLEXPOINT_SATURATE_INF) + if EPILOGUE_FN is not None: + acc = EPILOGUE_FN(acc, *epilogue_fn_args, target_dtype=YPtrs.dtype.element_ty) + tl.store(YPtrs, acc, mask=mask) + + +# Imagine N_EXPTS_ACT = 4, n_final_rows = 5, and n_scratchpad_rows = 8. +# Also imagine scatter_indx.src_indx is: +# (number of active experts per final row) +# -1 -1 0 -1 1 +# -1 2 -1 -1 1 +# 1 3 -1 -1 2 +# -1 4 5 6 3 +# -1 -1 -1 -1 0 (this row is unused) +# +# Then, row 0 and 1 can be written directly to the final tensor. +# In this case, WriteBackIndx looks like: +# [0] = 0 : intermediate row 0 is written directly to final row 0 +# [1] = 5+1=6 : scratchpad starts at offset 5 +# [2] = 1 : intermediate row 2 is written directly to final row 1 +# [3] = 5+3=8 +# [4] = 5+4=9 +# [5] = 5+5=10 +# [6] = 5+6=11 +# [7] = -1 : unused (there are only seven intermediate rows) +@triton.jit +def _compute_writeback_idx( + WriteBackIndx, + FinalizeScatterIdxs, + ScatterDstIndx, ScatterSrcIndx, + n_final_rows, n_scratchpad_rows, + BLOCK_M: tl.constexpr, + N_EXPTS_ACT: tl.constexpr, +): + tl.static_assert(N_EXPTS_ACT > 1) + + pid_m = tl.program_id(0) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = offs_m < n_scratchpad_rows + dst_idxs = tl.load(ScatterDstIndx + offs_m, mask=mask_m, other=-1) + # Load corresponding rows in ScatterSrcIndx. + mask = dst_idxs != -1 + src_offs = (dst_idxs // N_EXPTS_ACT) * N_EXPTS_ACT + src_offs = src_offs[:, None] + tl.arange(0, N_EXPTS_ACT)[None, :] + src_idxs = tl.load(ScatterSrcIndx + src_offs, mask=mask[:, None], other=-1) + # Compute the number of actually active experts. + is_src_active = (src_idxs != -1).to(tl.int32) + has_one_active = tl.sum(is_src_active, axis=1) == 1 + # Compute the writeback index. 
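+    # Rows with exactly one active expert skip the finalize pass, so they are written
+    # straight to their final row (dst_idx // N_EXPTS_ACT); every other row goes to the
+    # scratchpad region that starts at n_final_rows. In the worked example above,
+    # intermediate row 0 maps to final row 0 and intermediate row 1 maps to 5 + 1 = 6.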
+ wb_idx = tl.where(has_one_active, dst_idxs // N_EXPTS_ACT, n_final_rows + offs_m) + wb_idx = tl.where(mask, wb_idx, -1) + tl.store(WriteBackIndx + offs_m, wb_idx, mask=mask_m) + + if pid_m >= ((n_final_rows + BLOCK_M - 1) // BLOCK_M): + return + + mask_m = offs_m < n_final_rows + src_offs = offs_m[:, None] * N_EXPTS_ACT + tl.arange(0, N_EXPTS_ACT)[None, :] + src_idxs = tl.load(ScatterSrcIndx + src_offs, mask=mask_m[:, None], other=-1) + is_src_active = (src_idxs != -1).to(tl.int32) + has_one_active = tl.sum(is_src_active, axis=1) == 1 + + need_finalize_scatter = mask_m and not has_one_active + finalize_scatter_count = tl.sum(need_finalize_scatter.to(tl.int32)) + if finalize_scatter_count == 0: + return + pp_off = tl.atomic_add(FinalizeScatterIdxs + n_final_rows + n_scratchpad_rows, finalize_scatter_count) + + # need_finalize_scatter = [1, 0, 0, 1, 1, 0, 1, 0, 1] + # arange = [0, 1, 2, 3, 4, 5, 6, 7, 8] + arange = tl.arange(0, BLOCK_M) + # idxs = [0, _, _, 3, 4, _, 6, _, 8] + last = BLOCK_M - 1 + idxs = tl.where(need_finalize_scatter, arange, last) + # idxs = [0, 3, 4, 6, 8, _, _, _, _] + idxs = tl.sort(idxs) + # r = offs_m + # d = [r[0], r[3], r[4], r[6], r[8], r[-1], r[-1], r[-1], r[-1]] + d = tl.gather(offs_m, idxs, axis=0) + s = tl.gather(src_idxs, idxs.expand_dims(1).broadcast_to(src_idxs.shape), axis=0) + # store destination indices + Ptr = FinalizeScatterIdxs + pp_off + tl.store(Ptr + arange, d, mask=arange < finalize_scatter_count) + # store src indices + Ptr = FinalizeScatterIdxs + n_final_rows + pp_off * N_EXPTS_ACT + tl.store(Ptr + N_EXPTS_ACT * arange[:, None] + tl.arange(0, N_EXPTS_ACT)[None, :], s, mask=(arange < finalize_scatter_count)[:, None]) diff --git a/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py new file mode 100644 index 0000000..dcf1f94 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py @@ -0,0 +1,555 @@ +import functools +import torch +import triton +import triton.language as tl +from triton_kernels import target_info +from triton_kernels.numerics_details.mxfp import _unswizzle_mx_block, get_scaled_dot_format_string +from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale, nan_propagating_absmax_reduce, compute_scale +from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle + +# fmt: off + +@tl.constexpr_function +def cuda_capability_geq(major, minor): + return target_info.cuda_capability_geq(major, minor) + +# TODO: this is a limitation of the triton frontend +# we shouldn't have to do that! +def inline_function(f): + """ + Wraps an arbitrary Python function so that it can be inlined into a Triton function at compile-time. 
+ """ + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + + # disguise the function as a Triton builtin to avoid raising an error + # that we're calling a non-JIT function from within a Triton kernel: + wrapper.__triton_builtin__ = True + wrapper.__module__ = getattr(tl, "__name__", "triton.language") + return wrapper + +@inline_function +def _load_tensor_desc(desc, offs, transpose: tl.constexpr = False, _builder=None): + if transpose: + offs = offs[:-2] + [offs[-1], offs[-2]] + res = desc.load(offs, _builder=_builder) + res = tl.reshape(res, desc.block_shape[-2:], _builder=_builder) + if transpose: + res = tl.trans(res, _builder=_builder) + return res + + +# Helper function to recreate a TMA desc with the same fields, but with a new pointer and optional new shape. +@inline_function +def _update_tensor_desc(desc, ptr, shape=None, _builder=None): + return tl.make_tensor_descriptor( + ptr, + shape=shape or desc.shape, + # last dim must be constexpr 1; reflecting the old descriptor drops the constexpr + strides=desc.strides[:-1] + [tl.constexpr(1)], + block_shape=desc.block_shape, + _builder=_builder, + ) + +@triton.jit +def _make_tensor_desc(ptr, shape, strides, block_shape, transpose: tl.constexpr = False): + tl.static_assert(len(shape) == len(strides)) + tl.static_assert(len(strides) == len(block_shape)) + if transpose: + return tl.make_tensor_descriptor( + ptr, + shape=shape[:-2] + [shape[-1], shape[-2]], + strides=strides[:-2] + [strides[-1], tl.constexpr(1)], + block_shape=block_shape[:-2] + [block_shape[-1], block_shape[-2]], + ) + else: + return tl.make_tensor_descriptor( + ptr, + shape=shape, + strides=strides[:-1] + [tl.constexpr(1)], + block_shape=block_shape, + ) + +@triton.jit +def _load_tile_attrs( + tile_id, num_tiles, grid_m, grid_n, padding_m, + M, ExptData, ExptHist, ExptOffs, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, SPLIT_K: tl.constexpr, + GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr): + # unpack and swizzle program ids + pid_emnk = tile_id + if XCD_SWIZZLE != 1: + pid_emnk = xcd_swizzle(pid_emnk, num_tiles // SPLIT_K, XCD_SWIZZLE) + pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K) + pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K) + if SPLIT_K > 1: + pid_k = pid_mnk % SPLIT_K + pid_mn = pid_mnk // SPLIT_K + else: + pid_k: tl.constexpr = 0 + pid_mn = pid_mnk + pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M) + + # unpack expert data + if ExptData is None: + tl.static_assert(M is not None) + expt_id, start_z, start_m, block_id, eM = pid_e, pid_e, 0, pid_m, -1 + else: + tl.static_assert(M is None) + expt_data = tl.load(ExptData + pid_m) + expt_id = expt_data & 0x0000FFFF + block_id = expt_data >> 16 + eM = tl.load(ExptHist + expt_id) + start_m = tl.load(ExptOffs + expt_id) + start_z = 0 + + off_m = BLOCK_M * block_id + off_n = BLOCK_N * pid_n + + return expt_id, start_z, start_m, eM, off_m, off_n, pid_k + + +@triton.jit +def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask): + mask = mask & (offs < writeback_size) + offs = tl.load(WriteBackIndx + offs, mask=mask, other=-1) + mask = offs != -1 + return (offs, mask) + + +_matmul_ogs_repr = make_matmul_repr("_p_matmul_ogs", [0, 1, 2]) +@triton.jit(repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata) +def _p_matmul_ogs( + Y, Out, stride_y_k, stride_y_z, stride_y_m, stride_y_n, + YExpectedScale, YActualScale, YChecksumScale, + X, stride_x_z, stride_x_m, stride_x_k, + XScale, + W, stride_w_e, stride_w_k, 
stride_w_n, W_TRANSPOSE: tl.constexpr, + WScale, + MxScale, stride_mx_e, stride_mx_k, stride_mx_n, MX_TRANSPOSE: tl.constexpr, + B, stride_b_e, # Bias + NRows, M, N, K, # shapes + # expt data + Betas, Gammas, + GatherIndx, + ScatterSrcIndx, num_idxs, + WriteBackIndx, writeback_size, + ExptHist, ExptOffs, ExptOffsSum, ExptData, + # true grid size + batch_size, grid_m, grid_n, + # Out scale + out_alpha, + # epilogue transform + EPILOGUE_FN: tl.constexpr, epilogue_fn_args, + # MoE config + N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr, + # precision config + MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr, + FLEXPOINT_SATURATE_INF: tl.constexpr, + PER_BATCH_SCALE: tl.constexpr, + # optimization config + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr, SWIZZLE_MX: tl.constexpr, + EPILOGUE_SUBTILE: tl.constexpr, + EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, + W_CACHE_MODIFIER: tl.constexpr, + NUM_SMS: tl.constexpr, + TOKENS_PER_EXPT_FOR_ANNOTATION=None, + UPCAST_INDICES:tl.constexpr=False, + DISABLE_Y_TMA: tl.constexpr=False, + SWAP_XW: tl.constexpr = False): + Y = Out # Y is passed for the purposes of annotation; replace it with Out + + is_microscaled_format: tl.constexpr = MxScale is not None + MX_PACK_DIVISOR: tl.constexpr = 32 + if is_microscaled_format: + w_type: tl.constexpr = W.dtype.element_ty + tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5), + "mx_weight_ptr must be uint8") + tl.static_assert(MxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8") + tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") + + # We have pack 2 fp4 values in a byte + W_PACK_DIVISOR: tl.constexpr = 2 if W.dtype.element_ty == tl.uint8 else 1 + PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_PACK_DIVISOR + MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR + else: + W_PACK_DIVISOR: tl.constexpr = 1 + MX_SCALE_BLOCK_K: tl.constexpr = 1 + PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K + + if ExptOffsSum is not None: + # Determine how much padding there is on the expert data. This allows us to + # know the true grid size and avoid processing padding tiles. + padding_m = grid_m - tl.load(ExptOffsSum) + else: + padding_m: tl.constexpr = 0 + + HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None + index_type: tl.constexpr = tl.int64 + + # set masked out rows to 0 + if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1: + # Iterate with reversed pids so that later pids will get more tiles if the number of + # tiles isn't evenly divisible by the number of SMs. + # The main loop after this iterates in the forward direction such that earlier + # pids get more tiles if the number of tiles isn't evenly divisible. + # This helps balance the work across the SMs. 
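+        # For example, with NUM_SMS = 4 and 6 tiles, this loop starts programs 0..3 at
+        # tiles 3, 2, 1, 0, so programs 2 and 3 zero two tiles each, while the forward
+        # main loop below hands two tiles each to programs 0 and 1; every program ends up
+        # with three tiles in total (assuming both loops see the same number of tiles).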
+ for pid_mnk in range(NUM_SMS - tl.program_id(0) - 1, batch_size * grid_m * grid_n * SPLIT_K, NUM_SMS): + pid_k = pid_mnk % SPLIT_K + pid_mn = pid_mnk // SPLIT_K + pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M) + + offs_m = BLOCK_M * pid_m + tl.arange(0, BLOCK_M) + offs_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N) + src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0) + YPtrs = Y + offs_m.to(index_type)[:, None] * stride_y_m + offs_n[None, :] * stride_y_n + mask_n = offs_n < N + mask = (src_idx == -1)[:, None] & mask_n[None, :] + tl.store(YPtrs + pid_k * stride_y_k, tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32), mask=mask) + + USE_FLEXPOINT_SCALE: tl.constexpr = YActualScale is not None or YChecksumScale is not None + + INT_MAX: tl.constexpr = 2147483647 + HAS_TMA_GS: tl.constexpr = cuda_capability_geq(10, 0) + USE_GATHER_TMA: tl.constexpr = (HAS_TMA_GS and GatherIndx is not None) + X_USE_LOAD_TMA: tl.constexpr = GatherIndx is None and not USE_GATHER_TMA + USE_SCATTER_TMA: tl.constexpr = (HAS_TMA_GS and HAS_FUSED_SCATTER) and not DISABLE_Y_TMA + + if USE_GATHER_TMA: + x_desc = tl.make_tensor_descriptor( + X, + # No masking on the M dimension because we manually mask by setting indices to -1 + shape=[INT_MAX, K], + strides=[stride_x_m, stride_x_k], + block_shape=[1, BLOCK_K] + ) + elif X_USE_LOAD_TMA: + x_desc = tl.make_tensor_descriptor( + X, + # When M is ragged, we don't mask the input rows, but mask the accumulator result in the epilogue. + # So shape[0] here is the global number of rows in the X matrix, which allows using an invariant descriptor. + shape=[NRows, K], + strides=[stride_x_m, stride_x_k], + block_shape=[BLOCK_M, BLOCK_K] + ) + + w_desc = _make_tensor_desc(W, + shape=[N_EXPTS_TOT if ExptData is not None else batch_size, + (K + W_PACK_DIVISOR - 1) // W_PACK_DIVISOR, N], + strides=[stride_w_e, stride_w_k, stride_w_n], + block_shape=[1, PACKED_BLOCK_K_W, BLOCK_N], + transpose=W_TRANSPOSE) + + if is_microscaled_format: + PackedK = (K + MX_PACK_DIVISOR - 1) // MX_PACK_DIVISOR + if SWIZZLE_MX: + mx_desc = tl.make_tensor_descriptor( + MxScale, + shape=[ + N_EXPTS_TOT if ExptData is not None else batch_size, + (N + 127) // 128, (PackedK + 3) // 4, 32, 4 * 4, + ], + strides=[stride_mx_e, stride_mx_n, stride_mx_k, 4 * 4, 1], + block_shape=[1, BLOCK_N // 128, MX_SCALE_BLOCK_K // 4, 32, 4 * 4] + ) + else: + mx_desc = _make_tensor_desc( + MxScale, + shape=[N_EXPTS_TOT if ExptData is not None else batch_size, PackedK, N], + strides=[stride_mx_e, stride_mx_k, stride_mx_n], + block_shape=[1, MX_SCALE_BLOCK_K, BLOCK_N], + transpose=MX_TRANSPOSE + ) + + EPILOGUE_BLOCK_N: tl.constexpr = BLOCK_N // 2 if EPILOGUE_SUBTILE else BLOCK_N + + if USE_SCATTER_TMA: + y_desc = tl.make_tensor_descriptor( + Y, + # No masking on the M dimension because we manually mask by setting indices to INT_MAX + shape=[INT_MAX - 1, N], + strides=[stride_y_m, stride_y_n], + block_shape=[1, EPILOGUE_BLOCK_N], + ) + + k_tiles = tl.cdiv(K, BLOCK_K * SPLIT_K) + num_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K + + # If true, do not share loop-carried variables between the prologue and the + # epilogue to enable better pipelining with mmav5 + INDEPENDENT_EPILOGUE: tl.constexpr = cuda_capability_geq(10, 0) + + # start negative; will be incremented at the top of the loop + if INDEPENDENT_EPILOGUE: + tile_id1 = tl.program_id(0) - NUM_SMS + + # Keep track of local max for updating flexpoint scales. 
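+    # The running absolute maximum is kept as one uint32 bit pattern per thread so it can
+    # be reduced with an integer, NaN-propagating max; it is bitcast back to float32 and
+    # used to update YActualScale with a single relaxed atomic_max after the tile loop.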
+ THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads() + local_absmax = tl.full([THREADS_PER_BLOCK], 0.0, tl.uint32) + + DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_microscaled_format and BLOCK_M * BLOCK_N >= 128 * 256 + # Enable warp specialization when all loads are TMA loads. Don't enable it + # for mixed-precision yet. + ENABLE_WS: tl.constexpr = True + WARP_SPECIALIZE: tl.constexpr = (USE_GATHER_TMA or X_USE_LOAD_TMA) and ENABLE_WS + + for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=WARP_SPECIALIZE): + expt_id, start_z, start_m, eM, off_m, off_n, pid_k = _load_tile_attrs( + tile_id, num_tiles, grid_m, grid_n, padding_m, + M, ExptData, ExptHist, ExptOffs, + BLOCK_M, BLOCK_N, SPLIT_K, + GROUP_M, XCD_SWIZZLE) + + # Base pointers and offsets. These will be DCE'ed if unused in the TMA path. + XBase = X + start_z.to(index_type) * stride_x_z + offs_x_k = tl.arange(0, BLOCK_K)[None, :] * stride_x_k + if SPLIT_K > 1: + offs_x_k += pid_k.to(index_type) * BLOCK_K * stride_x_k + offs_w_n = off_n + tl.arange(0, BLOCK_N) + offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % N, BLOCK_N), BLOCK_N) + + # If the operands are swapped, the TMA layout of the MX scales are not optimal for the weights anymore. + # The scales will be loaded with normal loads instead. + if is_microscaled_format and SWAP_XW: + offs_mx_k = tl.arange(0, MX_SCALE_BLOCK_K) + + PACKED_MX_BLOCK: tl.constexpr = (MX_SCALE_BLOCK_K // 4) * 32 * 4 * 4 + offs_mx_inner = tl.arange(0, PACKED_MX_BLOCK) + offs_mx_outer = ((off_n // 128) + tl.arange(0, BLOCK_N // 128)) % N + offs_mx_outer = tl.max_contiguous(tl.multiple_of(offs_mx_outer, BLOCK_N // 128), BLOCK_N // 128) + + if SPLIT_K > 1: + offs_mx_k += MX_SCALE_BLOCK_K * pid_k + offs_mx_inner += PACKED_MX_BLOCK * pid_k + + if X_USE_LOAD_TMA: + if ExptData is None: + # start_z may change; update the descriptor + x_desc = _update_tensor_desc(x_desc, XBase) + else: + offs_m = off_m + tl.arange(0, BLOCK_M) + mask_m = offs_m < (M if M is not None else eM) + if USE_GATHER_TMA: + # Mask the gather indices and load -1 instead. TMA will handle OOB accesses. 
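+            # other=-N_EXPTS_ACT makes every masked-out entry floor-divide to -1 below;
+            # the gather TMA then treats those row indices as out-of-bounds accesses
+            # instead of reading from an invalid address.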
+ offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, + mask=mask_m, other=-N_EXPTS_ACT) // N_EXPTS_ACT + if ExptData is None: # start_z may change; update the descriptor + x_desc = _update_tensor_desc(x_desc, XBase) + else: + if M is not None: + offs_m = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M) + else: + offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M) + # no needs to bounds-check here because `offs_m` wraps around M dim + offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT + offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m + + acc = tl.zeros((BLOCK_N, BLOCK_M) if SWAP_XW else (BLOCK_M, BLOCK_N), dtype=tl.float32) + for ki in tl.range(k_tiles, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER): + off_k = pid_k * BLOCK_K + ki * BLOCK_K * SPLIT_K + off_k_w = pid_k * PACKED_BLOCK_K_W + ki * PACKED_BLOCK_K_W * SPLIT_K + off_k_mx = pid_k * MX_SCALE_BLOCK_K + ki * MX_SCALE_BLOCK_K * SPLIT_K + + if USE_GATHER_TMA: + x = x_desc.gather(offs_x_m, off_k) + elif X_USE_LOAD_TMA: + x = x_desc.load([start_m + off_m, off_k]) + else: + XPtrs = XBase + offs_x_m + offs_x_k + XBase += BLOCK_K * SPLIT_K * stride_x_k + if EVEN_K: + if SPLIT_K > 1: + x = tl.load(XPtrs, mask=off_k < K, other=0.0) + else: + x = tl.load(XPtrs) + else: + mask_k = tl.arange(0, BLOCK_K) < K - off_k + x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0) + + w = _load_tensor_desc(w_desc, [expt_id, off_k_w, off_n], transpose=W_TRANSPOSE) + + if is_microscaled_format: + x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype) + mx_format: tl.constexpr = get_scaled_dot_format_string(w.dtype) + if x_format == "fp16" or x_format == "bf16": + x_scales: tl.constexpr = None + else: + x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8) + if SWAP_XW: + if SWIZZLE_MX: + MxPtrs = MxScale + expt_id.to(index_type) * stride_mx_e + offs_mx_outer.to(index_type)[:, None] * stride_mx_n + offs_mx_inner[None, :] + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K) * stride_mx_k + w_scales = _unswizzle_mx_block(tl.load(MxPtrs)) + else: + MxPtrs = MxScale + expt_id.to(index_type) * stride_mx_e + offs_mx_k.to(index_type)[None, :] * stride_mx_k + offs_w_n.to(index_type)[:, None] * stride_mx_n + ki * MX_SCALE_BLOCK_K * SPLIT_K * stride_mx_k + if EVEN_K: + if SPLIT_K > 1: + w_scales = tl.load(MxPtrs, mask=off_k < K, other=0.0) + else: + w_scales = tl.load(MxPtrs) + else: + mask_k = offs_mx_k < tl.cdiv(K - off_k, MX_PACK_DIVISOR) + w_scales = tl.load(MxPtrs, mask=mask_k[None, :], other=0.0) + + elif SWIZZLE_MX: + w_scales = mx_desc.load([expt_id, off_n // 128, ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0]) + w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * 32 * 4 * 4)) + w_scales = _unswizzle_mx_block(w_scales) + else: + w_scales = _load_tensor_desc(mx_desc, [expt_id, off_k_mx, off_n], transpose=MX_TRANSPOSE).T + if SWAP_XW: + acc = tl.dot_scaled(w.T, w_scales, mx_format, x.T, x_scales, x_format, acc=acc, fast_math=True) + else: + acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True) + else: + if SWAP_XW: + acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32) + else: + acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32) + + if INDEPENDENT_EPILOGUE: + tile_id1 += NUM_SMS + expt_id1, start_z1, start_m1, eM1, off_m1, off_n1, pid_k1 = _load_tile_attrs( + tile_id1, num_tiles, grid_m, grid_n, padding_m, + M, ExptData, 
ExptHist, ExptOffs, + BLOCK_M, BLOCK_N, SPLIT_K, + GROUP_M, XCD_SWIZZLE) + else: + tile_id1, expt_id1, start_z1, start_m1, eM1 = tile_id, expt_id, start_z, start_m, eM + off_m1, off_n1, pid_k1 = off_m, off_n, pid_k + + # Determine output row offsets and mask + offs_m = off_m1 + tl.arange(0, BLOCK_M) + mask_m = offs_m < M if M is not None else offs_m < eM1 + if HAS_FUSED_SCATTER: + offs_y_m, mask_m = _load_writeback_idx_and_mask( + WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m) + # Later, mask out the acc for computing flexpoint scales. + MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE + + if USE_SCATTER_TMA and SPLIT_K > 1: + # Compute the split k offset in number of rows, and add it to offs_y_m. + # This allows us to write to the correct slice in the output tensor while using + # a 2D TMA scatter. + tl.device_assert(stride_y_k // stride_y_m == tl.cdiv(stride_y_k, stride_y_m)) + split_k_row_offs = pid_k1 * (stride_y_k // stride_y_m) + offs_y_m = tl.where(mask_m, offs_y_m + split_k_row_offs, offs_y_m) + else: + offs_y_m = start_m1 + offs_m + + if USE_GATHER_TMA: + MASK_ACC: tl.constexpr = False + else: + # Later, mask out the acc for computing flexpoint scales. + MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE + + # TMA is faster on Blackwell if a SWAP_XW transpose is not needed, or when we need registers to mask out the acc. + Y_USE_TMA: tl.constexpr = (MASK_ACC or cuda_capability_geq(10, 0)) and not (DISABLE_Y_TMA or SWAP_XW) + + YBase = Y + start_z1.to(index_type) * stride_y_z + start_m1.to(index_type) * stride_y_m + if USE_SCATTER_TMA: + if ExptData is None: # start_z1 may change; update the descriptor + y_desc = _update_tensor_desc(y_desc, YBase) + elif not HAS_FUSED_SCATTER and Y_USE_TMA: + y_desc = tl.make_tensor_descriptor( + YBase + pid_k1.to(index_type) * stride_y_k, + shape=[M if M is not None else eM1, N], + strides=[stride_y_m, stride_y_n], + block_shape=[BLOCK_M, EPILOGUE_BLOCK_N], + ) + + # bias + scale + offs_y_n = off_n1 + tl.arange(0, BLOCK_N) + mask_n = offs_y_n < N + if B is not None: + BPtrs = B + expt_id1 * stride_b_e + offs_y_n + if pid_k1 == 0: + bias = tl.load(BPtrs, mask=mask_n, other=0) + else: + bias = tl.full([BLOCK_N], 0, dtype=tl.float32) + else: + bias = tl.full([BLOCK_N], 0, dtype=tl.float32) + if Betas is not None: + betas = tl.load(Betas + start_m1 + offs_m, mask=mask_m, other=0.0) + else: + betas = tl.full([BLOCK_M], 1, dtype=tl.float32) + if Gammas is not None: + gammas = tl.load(Gammas + start_m1 + offs_m, mask=mask_m, other=0.0) + else: + gammas = tl.full([BLOCK_M], 1, dtype=tl.float32) + x_scale = load_scale(XScale) + if PER_BATCH_SCALE: + w_scale = load_scale(WScale + expt_id1) + else: + w_scale = load_scale(WScale) + + if EPILOGUE_SUBTILE: + accs = tl.split(tl.permute(tl.reshape(acc, (BLOCK_M, 2, EPILOGUE_BLOCK_N)), (0, 2, 1))) + biases = tl.split(tl.permute(tl.reshape(bias, (2, EPILOGUE_BLOCK_N)), (1, 0))) + else: + accs = (acc,) + biases = (bias,) + + for a_i in tl.static_range(len(accs)): + acc_tile = accs[a_i] + acc_tile *= x_scale * w_scale + + if SWAP_XW: + acc_tile = acc_tile.T + acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None] + acc_tile *= gammas[:, None] + if out_alpha is not None: + acc_tile *= out_alpha + + if MASK_ACC: + acc_tile = tl.where(mask_m[:, None], acc_tile, 0.0) + + # Flexpoint + acc_view = tl.reshape( + acc_tile, [acc_tile.numel // THREADS_PER_BLOCK, THREADS_PER_BLOCK], can_reorder=True) + local_absmax = tl.maximum(local_absmax, nan_propagating_absmax_reduce(acc_view, axis=0)) + acc_tile = float_to_flex( + 
acc_tile, YExpectedScale, + None, # ActualScale: local absmax is tracked and updated after the loop + YChecksumScale, + None, # mask: acc is manually masked to 0 + Y, FLEXPOINT_SATURATE_INF) + if EPILOGUE_FN is not None: + acc_tile = EPILOGUE_FN(acc_tile, *epilogue_fn_args, target_dtype=Y.dtype.element_ty, pid=len(accs)*tile_id1 + a_i) + + if USE_SCATTER_TMA: + # Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that + # there shouldn't be any other negative values. + offs_y_m = (offs_y_m.to(tl.uint32, bitcast=True) & 0x7FFFFFFF).to(tl.int32, bitcast=True) + y_desc.scatter(acc_tile.to(Y.dtype.element_ty), offs_y_m, off_n1 + a_i * EPILOGUE_BLOCK_N) + elif not HAS_FUSED_SCATTER and Y_USE_TMA: + y_desc.store([off_m1, off_n1 + a_i * EPILOGUE_BLOCK_N], acc_tile.to(Y.dtype.element_ty)) + else: + offs_y_n = off_n1 + a_i * EPILOGUE_BLOCK_N + tl.arange(0, EPILOGUE_BLOCK_N) + mask_n = offs_y_n < N + + YPtrs = Y + pid_k1.to(index_type) * stride_y_k + start_z1.to(index_type) * stride_y_z + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n[None, :] * stride_y_n + mask = mask_m[:, None] & mask_n[None, :] + tl.store(YPtrs, acc_tile, mask=mask) + + + # Update the flexpoint scales + if YActualScale is not None: + tl.atomic_max(YActualScale, compute_scale(local_absmax.to(tl.float32, bitcast=True), Y), sem="relaxed") + + +_per_device_alloc_fns = {} +def get_per_device_per_stream_alloc_fn(device): + if device not in _per_device_alloc_fns: + _per_stream_tensors = {} + def alloc_fn(size: int, alignment: int, stream): + assert alignment == 128 + if stream not in _per_stream_tensors or _per_stream_tensors[stream].numel() < size: + _per_stream_tensors[stream] = torch.empty(size, device=device, dtype=torch.int8) + _per_stream_tensors[stream].__hibernate__ = {"type": "ignore"} + return _per_stream_tensors[stream] + + _per_device_alloc_fns[device] = alloc_fn + return _per_device_alloc_fns[device] diff --git a/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_weight_transpose.py b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_weight_transpose.py new file mode 100644 index 0000000..f16b585 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/_weight_transpose.py @@ -0,0 +1,55 @@ +import triton +import triton.language as tl + + +@triton.jit +def _weight_transpose( + M: tl.constexpr, + N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + X, + stride_xe: tl.constexpr, + stride_xm: tl.constexpr, + stride_xn: tl.constexpr, + Y, + stride_ye: tl.constexpr, + stride_ym: tl.constexpr, + stride_yn: tl.constexpr, +): + pid_m = tl.program_id(0).to(tl.int64) + pid_n = tl.program_id(1).to(tl.int64) + pid_e = tl.program_id(2).to(tl.int64) + + X += stride_xe * pid_e + Y += stride_ye * pid_e + + m_exact: tl.constexpr = (M % BLOCK_M) == 0 + n_exact: tl.constexpr = (N % BLOCK_N) == 0 + + off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + mask_m = off_m < M + mask_n = off_n < N + + if m_exact: + if n_exact: + mask = None + other = None + else: + mask = mask_n[None, :] + other = 0 + else: + if n_exact: + mask = mask_m[:, None] + other = 0 + else: + mask = mask_m[:, None] & mask_n[None, :] + other = 0 + + X_ptrs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn + Y_ptrs = Y + off_m[:, None] * stride_ym + off_n[None, :] * stride_yn + + tile = tl.load(X_ptrs, mask=mask, other=other) + tl.store(Y_ptrs, tile, mask=mask) diff --git 
a/kernel-microbench/tk/triton_kernels/matmul_ogs_details/fast_contiguous.py b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/fast_contiguous.py new file mode 100644 index 0000000..b8f5d9b --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/fast_contiguous.py @@ -0,0 +1,54 @@ +import torch +from ._weight_transpose import _weight_transpose + + +def weight_transpose(X, Y, BLOCK_M=128, BLOCK_N=128, num_warps=16) -> None: + if X.dtype.itemsize == 1: + X = X.view(torch.int8) + if Y.dtype.itemsize == 1: + Y = Y.view(torch.int8) + + # check compatibility: + assert X.shape == Y.shape + assert X.dtype == Y.dtype + + # this doubles up as an assertion: + is_3d = {3: True, 2: False}[len(X.shape)] + + M = X.shape[-2] + N = X.shape[-1] + E = X.shape[0] if is_3d else 1 + + stride_xm = X.stride(-2) + stride_xn = X.stride(-1) + stride_xe = X.stride(0) if is_3d else 0 + + stride_ym = Y.stride(-2) + stride_yn = Y.stride(-1) + stride_ye = Y.stride(0) if is_3d else 0 + + grid = ((M + BLOCK_M - 1) // BLOCK_M, (N + BLOCK_N - 1) // BLOCK_N, E) + + _weight_transpose[grid]( + M, + N, + BLOCK_M, + BLOCK_N, + X, + stride_xe, + stride_xm, + stride_xn, + Y, + stride_ye, + stride_ym, + stride_yn, + num_warps=num_warps, + ) + + +def fast_contiguous(X): + if X.is_contiguous(): + return X + Y = torch.empty(X.shape, device=X.device, dtype=X.dtype) + weight_transpose(X, Y) + return Y diff --git a/kernel-microbench/tk/triton_kernels/matmul_ogs_details/metadata.py b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/metadata.py new file mode 100644 index 0000000..bb03eda --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/metadata.py @@ -0,0 +1,89 @@ +from dataclasses import dataclass +import torch +import triton +import triton.language as tl + + +@dataclass +class ExptData: + hist: torch.Tensor + offs: torch.Tensor + offs_sum: torch.Tensor + blocks: torch.Tensor + buffer: torch.Tensor + + +@triton.jit +def _matmul_metadata_memset(Hist, n_expts_tot, MDHist, MDTokStarts, MDTileStarts, MDTileInfo, md_n_tiles, + BLOCK: tl.constexpr, TILE_DIM: tl.constexpr): + pid = tl.program_id(0) + # if pid == 0 - initialize cumsums + if pid == 0: + x_tok = tl.zeros([BLOCK], dtype=MDTokStarts.dtype.element_ty) + x_tile = tl.zeros([BLOCK], dtype=MDTileStarts.dtype.element_ty) + tl.store(MDTokStarts, 0) + tl.store(MDTileStarts, 0) + for i in range(0, n_expts_tot, BLOCK): + offs_n = tl.arange(0, BLOCK) + i + mask = offs_n < n_expts_tot + hist_tok = tl.load(Hist + offs_n, mask=mask) + hist_tile = tl.cdiv(hist_tok, TILE_DIM) + tok_starts = tl.cumsum(hist_tok, 0) + x_tok + x_tok += tl.sum(hist_tok, 0).to(MDTokStarts.dtype.element_ty) + tile_starts = tl.cumsum(hist_tile, 0) + x_tile + x_tile += tl.sum(hist_tile, 0).to(MDTileStarts.dtype.element_ty) + tl.store(MDHist + offs_n, hist_tok, mask=mask) + tl.store(MDTokStarts + 1 + offs_n, tok_starts, mask=mask) + tl.store(MDTileStarts + 1 + offs_n, tile_starts, mask=mask) + + # initialize block data + offs = pid * BLOCK + tl.arange(0, BLOCK) + tl.store(MDTileInfo + offs, 0xffffffff, mask=offs < md_n_tiles) + + +@triton.jit +def _matmul_metadata_compute(Hist, MDTileStarts, MDTileInfo, BLOCK: tl.constexpr, TILE_DIM: tl.constexpr): + + expt_id = tl.program_id(0) + n_tokens = tl.load(Hist + expt_id) + n_blocks = tl.cdiv(n_tokens, TILE_DIM) + + tile_off = tl.load(MDTileStarts + expt_id) + MDTileInfo += tile_off + # MDTileInfo += tl.load(MDTilesStart + expt_id) + for block_off in range(0, n_blocks, BLOCK): + block_offs = block_off + tl.arange(0, BLOCK) + 
data = (block_offs << 16) + expt_id + tl.store(MDTileInfo + block_offs, data, mask=block_offs < n_blocks) + + +def compute_metadata(routing_data, n_rows, block_m): + if routing_data.expt_hist is None: + return ExptData(None, None, None, None, None) + MEMSET_BLOCK = 512 + HIST2_BLOCK_M = 512 + device = routing_data.expt_hist.device + n_expts_tot = routing_data.n_expts_tot + cdiv = triton.cdiv + if n_rows <= n_expts_tot: + grid_m = n_rows + else: + grid_m = n_expts_tot - 1 - ((n_expts_tot - n_rows - 1) // block_m) + metadata_size = 3 * n_expts_tot + 2 + grid_m + metadata = torch.empty(metadata_size, dtype=torch.int32, device=device) + md_hist = metadata[:n_expts_tot] + md_offs = metadata[n_expts_tot:n_expts_tot * 2 + 1] + md_offs_sum = metadata[3 * n_expts_tot + 2 - 1] + md_tile_starts = metadata[n_expts_tot * 2 + 1:n_expts_tot * 3 + 2] + md_tile_infos = metadata[n_expts_tot * 3 + 2:] + _matmul_metadata_memset[(cdiv(metadata_size, MEMSET_BLOCK), )]( + routing_data.expt_hist, n_expts_tot, md_hist, md_offs, md_tile_starts, md_tile_infos, md_tile_infos.shape[0], + BLOCK=MEMSET_BLOCK, # optimization parameters + TILE_DIM=block_m, # constants + ) + _matmul_metadata_compute[(n_expts_tot, )]( + routing_data.expt_hist, md_tile_starts, md_tile_infos, # outputs + BLOCK=HIST2_BLOCK_M, # optimization parameters + TILE_DIM=block_m, # constants + ) + return ExptData(md_hist, md_offs, md_offs_sum, md_tile_infos, metadata) diff --git a/kernel-microbench/tk/triton_kernels/matmul_ogs_details/opt_flags.py b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/opt_flags.py new file mode 100644 index 0000000..2a01b34 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/opt_flags.py @@ -0,0 +1,277 @@ +from dataclasses import dataclass +import triton +from triton_kernels import target_info +import torch + +from . 
import opt_flags_amd, opt_flags_nvidia + +# fmt: off + +@dataclass +class OptFlags: + block_m: int + block_n: int + block_k: int + num_warps: int + num_stages: int + group_m: int + xcd_swizzle: int + w_cache_modifier: str + split_k: int + fused_scatter: bool + is_persistent: bool + epilogue_subtile: bool + arch: str + target_kernel_kwargs: dict + + def __post_init__(self): + if self.fused_scatter and self.split_k != 1: + raise ValueError("Not supported") + + + +def make_default_opt_flags_amd( + out_dtype, + lhs_dtype, + rhs_dtype, + precision_config, + microscaling_ctx, + m, + n, + k, + routing_data, + can_use_persistent_tma, + can_use_fused_scatter, + enforce_bitwise_invariance, + has_expensive_epilogue, + constraints, +): + assert not constraints, "flags constraints not supported on AMD" + # tokens per expert + if routing_data is None: + tokens_per_expt = m + elif routing_data.expected_tokens_per_expt is None: + tokens_per_expt = max(1, m // routing_data.n_expts_tot) + else: + tokens_per_expt = routing_data.expected_tokens_per_expt + # block_m + if constraints.get("block_m", None): + block_m = constraints["block_m"] + elif enforce_bitwise_invariance: + block_m = 128 + elif tokens_per_expt >= 512 and n >= 2048: + block_m = 128 + else: + block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64)) + if routing_data is not None: + grid_m = routing_data.n_blocks(m, block_m) + else: + grid_m = triton.cdiv(m, block_m) + # group_m: + group_m = 4 + # number of xcds + num_xcds = 8 + xcd_swizzle = num_xcds + # block_nk: + block_n, block_k = opt_flags_amd.compute_block_nk( + n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, microscaling_ctx + ) + # split_k: + grid_size = grid_m * ((n + block_n - 1) // block_n) + n_cu = torch.cuda.get_device_properties(0).multi_processor_count + if enforce_bitwise_invariance: + split_k = 1 + else: + split_k = max(1, n_cu // grid_size) + # w_cache_modifier: + w_cache_modifier = ".cg" if block_m <= 32 else None + # num_warps, num_stages + num_warps = 2 if (m is not None and m <= 16) else 8 + num_stages = 2 + is_persistent = False + # AMD-specific + target_kernel_kwargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1} + return OptFlags( + block_m=block_m, + block_n=block_n, + block_k=block_k, + num_warps=num_warps, + num_stages=num_stages, + group_m=group_m, + xcd_swizzle=xcd_swizzle, + w_cache_modifier=w_cache_modifier, + split_k=split_k, + fused_scatter=False, + is_persistent=is_persistent, + epilogue_subtile=False, + arch=None, + target_kernel_kwargs=target_kernel_kwargs, + ) + +def make_default_opt_flags_nvidia( + out_dtype, + lhs_dtype, + rhs_dtype, + precision_config, + microscaling_ctx, + m, + n, + k, + routing_data, + can_use_persistent_tma, + can_use_fused_scatter, + enforce_bitwise_invariance, + has_expensive_epilogue, + constraints, +): + constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"] + assert not any([c not in constraints_supported for c in constraints]), constraints.keys() + # tokens per expert + if routing_data is None: + tokens_per_expt = m + elif routing_data.expected_tokens_per_expt is None: + tokens_per_expt = max(1, m // routing_data.n_expts_tot) + else: + tokens_per_expt = routing_data.expected_tokens_per_expt + # pid swizzling + group_m = 8 + xcd_swizzle = 1 + # block_m + if constraints.get("block_m", None): + block_m = constraints["block_m"] + elif enforce_bitwise_invariance: + block_m = 128 + else: + block_m = max(64, 
min(triton.next_power_of_2(tokens_per_expt), 128)) + # TODO: remove when triton is more optimized for H100 MXFP4 + arch = None + if ( + block_m < 128 + and rhs_dtype == torch.uint8 + and microscaling_ctx.weight_scale is not None + and not target_info.cuda_capability_geq(10, 0) + ): + arch = "sm80" + # block n + block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config) + # is_persistent + grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n) + n_sms = torch.cuda.get_device_properties(0).multi_processor_count + tiles_per_sm = grid_size / n_sms + supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9) + if constraints.get("is_persistent", None) is not None: + is_persistent = constraints["is_persistent"] + else: + has_simple_epilogue = precision_config.max_num_imprecise_acc is None and not has_expensive_epilogue + is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4 + # block k + if constraints.get("block_k", None) is not None: + block_k = constraints["block_k"] + else: + block_k = opt_flags_nvidia.compute_block_k(k, is_persistent, lhs_dtype, rhs_dtype, microscaling_ctx) + # split_k + if constraints.get("split_k", None) is not None: + split_k = constraints["split_k"] + elif is_persistent or enforce_bitwise_invariance: + split_k = 1 + else: + estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, m, n, block_m, block_n) + split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size) + if split_k > 1: + # With split_k, results are written in f32. Use that for the following computations. + out_dtype = torch.float32 + compute_num_stages_args = ( + precision_config, + microscaling_ctx, + is_persistent, + block_m, + block_n, + block_k, + out_dtype, + lhs_dtype, + rhs_dtype, + ) + if constraints.get("epilogue_subtile", None) is not None: + epilogue_subtile = constraints["epilogue_subtile"] + else: + n1 = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, False, has_expensive_epilogue) + n2 = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, True, has_expensive_epilogue) + epilogue_subtile = n2 > n1 # enable epilogue_subtile if it increases the number of stages + # num_stages + num_stages = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, epilogue_subtile, has_expensive_epilogue) + # fused scatter scratchpad + if constraints.get("fused_scatter", None) is not None: + fused_scatter = constraints["fused_scatter"] + else: + fused_scatter = can_use_fused_scatter and split_k == 1 + # num_warps + num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n) + ret = OptFlags( + block_m=block_m, + block_n=block_n, + block_k=block_k, + num_warps=num_warps, + num_stages=num_stages, + group_m=group_m, + xcd_swizzle=xcd_swizzle, + w_cache_modifier=None, + split_k=split_k, + fused_scatter=fused_scatter, + is_persistent=is_persistent, + epilogue_subtile=epilogue_subtile, + arch=arch, + target_kernel_kwargs=dict(), + ) + # check constraints + assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}" + return ret + +# -------------- +# User Interface +# -------------- + +_opt_flags_constraints: dict = dict() +_opt_flags: OptFlags | None = None + +def update_opt_flags_constraints(constraints: dict[str, int]): + global _opt_flags_constraints + _opt_flags_constraints.update(constraints) + +def reset_opt_flags_constraints(): + global 
_opt_flags_constraints + _opt_flags_constraints = None + +def set_opt_flags(opt_flags: OptFlags): + global _opt_flags + assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override" + assert not _opt_flags, "opt_flags already set; please reset to None first" + _opt_flags = opt_flags + + +def make_opt_flags( + out_dtype, + lhs_dtype, + rhs_dtype, + precision_config, + m, + n, + k, + routing_data, + can_use_persistent_tma, + can_use_fused_scatter, + has_expensive_epilogue, +): + microscaling_ctx = precision_config.mx_ctx + enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance + if _opt_flags is not None: + assert not _opt_flags_constraints + return _opt_flags + args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, microscaling_ctx, m, n, k, + routing_data, can_use_persistent_tma, can_use_fused_scatter, + enforce_bitwise_invariance, has_expensive_epilogue, _opt_flags_constraints] + backend = triton.runtime.driver.active.get_current_target().backend + if backend == "hip": + return make_default_opt_flags_amd(*args) + if backend == "cuda": + return make_default_opt_flags_nvidia(*args) + assert False diff --git a/kernel-microbench/tk/triton_kernels/matmul_ogs_details/opt_flags_amd.py b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/opt_flags_amd.py new file mode 100644 index 0000000..f0080f2 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/opt_flags_amd.py @@ -0,0 +1,32 @@ +import torch +import triton +from triton_kernels.target_info import get_cdna_version + + +def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, microscaling_ctx): + lhs_width = lhs_dtype.itemsize + rhs_width = rhs_dtype.itemsize if microscaling_ctx.weight_scale is None else 0.5 + + # block_n: + n_cu = torch.cuda.get_device_properties(0).multi_processor_count + if n is not None: + if n <= 128 and (n & (n - 1)) == 0: + block_n = n + else: + block_n = max(32, min(256, triton.next_power_of_2(grid_m * n * num_xcds // n_cu))) + elif block_m > 64: + block_n = 256 + else: + block_n = 128 + + if get_cdna_version() == 4 and block_m == 128: + block_n = 512 + + # block_k needs to match the cacheline size (128B) + block_k = int(128 // min(lhs_width, rhs_width)) + + # TODO: block_k = 128 seems to work better for now. 
+ # perhaps due to increased number of k loops to pipeline + if microscaling_ctx.weight_scale is not None: + block_k = 128 + return block_n, block_k diff --git a/kernel-microbench/tk/triton_kernels/matmul_ogs_details/opt_flags_nvidia.py b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/opt_flags_nvidia.py new file mode 100644 index 0000000..d268ff8 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/matmul_ogs_details/opt_flags_nvidia.py @@ -0,0 +1,111 @@ +import torch +import triton +from triton_kernels import target_info + + +def compute_grid_size(routing_data, m, n, block_m, block_n): + if routing_data is not None: + grid_m = routing_data.n_blocks(m, block_m) + else: + grid_m = triton.cdiv(m, block_m) + grid_n = (n + block_n - 1) // block_n + return grid_m * grid_n + + +def compute_block_n(n: int, arch, precision_config): + capability = torch.cuda.get_device_capability()[0] if arch is None else int(arch[2:-1]) + # block_n: + block_n = max(16, min(128, triton.next_power_of_2(n))) + if capability >= 9 and precision_config.max_num_imprecise_acc is None and n > 128: + block_n = 256 + return block_n + + +def compute_block_k(k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, mx_ctx): + has_mx_weight_scale = mx_ctx and mx_ctx.weight_scale is not None + lhs_width = lhs_dtype.itemsize + rhs_width = rhs_dtype.itemsize + if has_mx_weight_scale: + rhs_width = 0.5 + # block_k needs to match the cacheline size (128B) + block_k = int(128 // min(lhs_width, rhs_width)) + # TODO: revisit when Triton is better for H100 + MXFP4 + has_native_mxfp = target_info.cuda_capability_geq(10, 0) + if rhs_width == 0.5 and not has_native_mxfp: + block_k = 128 + elif k is not None: + block_k = max(32, min(triton.next_power_of_2(k), block_k)) + + if has_native_mxfp and is_persistent and has_mx_weight_scale: + # Cap block_k to conserve smem to increase num_stages + block_k = min(block_k, 128) + return block_k + + +def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int: + device_props = torch.cuda.get_device_properties(0) + n_sms = device_props.multi_processor_count + split_k = n_sms // grid_size + if k is not None: + # avoid split_k for small k + num_block_k = triton.cdiv(k, block_k) + split_k = min(split_k, num_block_k // 4) + split_k = max(split_k, 1) + return split_k + + +def compute_num_warps(block_m, block_n): + return max(block_m * block_n // 4096, 4) + + +def compute_num_stages( + precision_config, + microscaling_ctx, + is_persistent, + block_m, + block_n, + block_k, + out_dtype, + lhs_dtype, + rhs_dtype, + epilogue_subtile, + has_expensive_epilogue, +): + if precision_config.max_num_imprecise_acc is not None: + return 3 + weight_size = 0.5 if rhs_dtype == torch.uint8 else rhs_dtype.itemsize + stage_size = block_m * block_k * lhs_dtype.itemsize + block_k * block_n * weight_size + device_props = torch.cuda.get_device_properties(0) + smem_capacity = device_props.shared_memory_per_block_optin + has_native_mxfp = target_info.cuda_capability_geq(10, 0) + if has_native_mxfp and microscaling_ctx is not None: + if microscaling_ctx.weight_scale is not None: + if rhs_dtype == torch.uint8: + # 4-bit e2m1 weights are padded 2x + # https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory + stage_size += block_k * block_n * weight_size + + if is_persistent: + # Per-stage wait barrier + stage_size += 8 + acc_size = out_dtype.itemsize + if target_info.cuda_capability_geq(10, 0): + acc_size = 4 if has_expensive_epilogue else 
out_dtype.itemsize + else: + acc_size = out_dtype.itemsize + if target_info.cuda_capability_geq(10, 0) and epilogue_subtile and not has_expensive_epilogue: + acc_block_n = block_n // 2 + else: + acc_block_n = block_n + # pipelined TMA store local to global, or + # pipelined layout conversion before store of the accumulator + # note: layout conversion has some padding + smem_capacity -= (block_m + 4) * acc_block_n * acc_size + if microscaling_ctx.weight_scale is not None: + # mx scales + stage_size += block_n * (block_k // 32) + elif has_native_mxfp: + # mx scales + stage_size += block_n * (block_k // 32) + num_stages = min(4, smem_capacity // int(stage_size)) + return num_stages diff --git a/kernel-microbench/tk/triton_kernels/numerics.py b/kernel-microbench/tk/triton_kernels/numerics.py new file mode 100644 index 0000000..024d3fc --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/numerics.py @@ -0,0 +1,42 @@ +import torch +from dataclasses import dataclass + +MAX_FINITE_FLOAT8E5 = 57344.0 +MAX_FINITE_FLOAT8E4NV = 448.0 +MAX_FINITE_FLOAT8E4B8 = 240.0 + + +@dataclass(frozen=True) +class BaseFlexData: + dtype: torch.dtype | None = None + + def view(self, x: torch.Tensor): + if self.dtype is None: + return x + return x.view(self.dtype) + + def reinterpret(self, x): + if self.dtype is None or x.dtype.itemsize > 1: + return x + return x.view(self.dtype) + + +@dataclass(frozen=True) +class InFlexData(BaseFlexData): + scale: torch.Tensor | None = None + + @property + def is_per_batch(self): + return False if self.scale is None else len(self.scale) > 1 + + +@dataclass(frozen=True) +class OutFlexData(BaseFlexData): + expected_scale: torch.Tensor | None = None + actual_scale: torch.Tensor | None = None + checksum_scale: torch.Tensor | None = None + + def __iter__(self): + yield self.expected_scale + yield self.actual_scale + yield self.checksum_scale diff --git a/kernel-microbench/tk/triton_kernels/numerics_details/__init__.py b/kernel-microbench/tk/triton_kernels/numerics_details/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernel-microbench/tk/triton_kernels/numerics_details/flexpoint.py b/kernel-microbench/tk/triton_kernels/numerics_details/flexpoint.py new file mode 100644 index 0000000..9f9075b --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/numerics_details/flexpoint.py @@ -0,0 +1,195 @@ +from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5 +from triton_kernels import target_info +import triton +import triton.language as tl + +# ------------------------------- +# Kernels stuff +# ------------------------------- + +TL_MAX_FINITE_FLOAT8E5 = tl.constexpr(MAX_FINITE_FLOAT8E5) +TL_MAX_FINITE_FLOAT8E4NV = tl.constexpr(MAX_FINITE_FLOAT8E4NV) +TL_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(MAX_FINITE_FLOAT8E4B8) +TL_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(1.750) +TL_MAX_FINITE_FLOAT16 = tl.constexpr(65472.0) + +TL_RCP_MAX_FINITE_FLOAT8E5 = tl.constexpr(0x37924925) # 0x1.24924Ap-16 +TL_RCP_MAX_FINITE_FLOAT8E4NV = tl.constexpr(0x3B124925) # 0x1.24924Ap-9 +TL_RCP_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(0x3B888889) # 0x1.111112p-8 +TL_RCP_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(0x3F124925) # 0x1.24924Ap-1 +TL_RCP_MAX_FINITE_FLOAT16 = tl.constexpr(0x37802008) # 0x1.004010p-16 + + +@triton.jit +def max_finite(dtype): + if dtype == tl.constexpr(tl.float8e5): + return TL_MAX_FINITE_FLOAT8E5 + elif dtype == tl.constexpr(tl.float8e4nv): + return TL_MAX_FINITE_FLOAT8E4NV + elif dtype == tl.constexpr(tl.float8e4b8): + return TL_MAX_FINITE_FLOAT8E4B8 + 
elif dtype == tl.constexpr(tl.float8e4b15): + return TL_MAX_FINITE_FLOAT8E4B15 + elif dtype == tl.constexpr(tl.float16): + return TL_MAX_FINITE_FLOAT16 + else: + tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint") + + +@triton.jit +def rcp_max_finite(dtype): + if dtype == tl.constexpr(tl.float8e5): + return TL_RCP_MAX_FINITE_FLOAT8E5 + elif dtype == tl.constexpr(tl.float8e4nv): + return TL_RCP_MAX_FINITE_FLOAT8E4NV + elif dtype == tl.constexpr(tl.float8e4b8): + return TL_RCP_MAX_FINITE_FLOAT8E4B8 + elif dtype == tl.constexpr(tl.float8e4b15): + return TL_RCP_MAX_FINITE_FLOAT8E4B15 + elif dtype == tl.constexpr(tl.float16): + return TL_RCP_MAX_FINITE_FLOAT16 + else: + tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint") + + +@tl.constexpr_function +def cuda_capability_geq(major, minor): + return target_info.cuda_capability_geq(major, minor) + + +@triton.jit +def sm86_min_nan_xorsign_abs_f32(a, b): + """Wrapper for min.NaN.xorsign.abs.f32 PTX instruction. + + Computes the minimum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs. + NaN inputs are propagated to the output. + + Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do). + """ + tl.static_assert(cuda_capability_geq(8, 6), "min.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+") + tl.static_assert(a.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs") + tl.static_assert(b.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs") + + return tl.inline_asm_elementwise( + """{ + min.NaN.xorsign.abs.f32 $0, $1, $2; + }""", + "=r,r,r", + [a, b], + dtype=tl.float32, + is_pure=True, + pack=1, + ) + + +@triton.jit +def sm86_max_nan_xorsign_abs_f32(a, b): + """Wrapper for max.NaN.xorsign.abs.f32 PTX instruction. + + Computes the maximum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs. + NaN inputs are propagated to the output. + + Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do). + """ + tl.static_assert(cuda_capability_geq(8, 6), "max.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+") + tl.static_assert(a.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs") + tl.static_assert(b.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs") + + return tl.inline_asm_elementwise( + """{ + max.NaN.xorsign.abs.f32 $0, $1, $2; + }""", + "=r,r,r", + [a, b], + dtype=tl.float32, + is_pure=True, + pack=1, + ) + + +@triton.jit +def load_scale(scale_ptr): + return 1.0 if scale_ptr is None else tl.load(scale_ptr) + + +@triton.jit +def flex_to_float(x, scale_ptr): + scale = load_scale(scale_ptr) + return x.to(tl.float32) * scale + + +@triton.jit +def clip(x, limit): + res = tl.minimum(x, limit) + res = tl.maximum(-limit, res) + return res + + +@triton.jit +def nan_propagating_absmax_reduce(x, axis=None): + if cuda_capability_geq(8, 6): + # abs-max-reduce as floating-point if `max.NaN.xorsign.abs.f32` is supported. + x_absmax = tl.reduce(x, axis, sm86_max_nan_xorsign_abs_f32) + # Note: sign of reduction result is the xor of signs of all inputs, explicitly clear the sign bit to fix it. 
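+        # Masking with 0x7FFFFFFF drops the IEEE-754 sign bit of the float32 bit pattern,
+        # i.e. it takes the absolute value of the reduced result.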
+ x_absmax = x_absmax.to(tl.uint32, bitcast=True) & 0x7FFFFFFF + else: + # Clear the sign bit, max-reduce as integer (same as NaN-propagating max-reduce as float) + masked_abs_x = x.to(tl.uint32, bitcast=True) & 0x7FFFFFFF + x_absmax = tl.max(masked_abs_x, axis) + + return x_absmax + + +@triton.jit +def compute_scale(x, Out): + x_absmax = nan_propagating_absmax_reduce(tl.ravel(x, can_reorder=True)) + + # atomic_max does not propagate NaNs, so we replace them with +inf (0x7f800000). + # We use integer minimum because NaNs are above +inf in integer representation. + x_absmax = tl.minimum(x_absmax, 0x7F800000).to(tl.float32, bitcast=True) + RCP_MAX_VALUE = rcp_max_finite(Out.dtype.element_ty) + return tl.fma(x_absmax, RCP_MAX_VALUE.to(tl.float32, bitcast=True), 1.0e-30) + + +@triton.jit +def update_scale(x, scale_ptr, Out) -> None: + if scale_ptr is not None: + scale = compute_scale(x, Out) + tl.atomic_max(scale_ptr, scale, sem="relaxed") + + +@triton.jit +def float_to_flex( + x, + expected_scale_ptr_or_val, + actual_scale_ptr, + checksum_scale_ptr, + mask, + Out, + saturate_infs: tl.constexpr, +): + if expected_scale_ptr_or_val is not None: + if expected_scale_ptr_or_val.dtype.is_ptr(): + invscale = 1.0 / tl.load(expected_scale_ptr_or_val) + else: + invscale = 1.0 / expected_scale_ptr_or_val + else: + invscale = 1.0 + if checksum_scale_ptr is not None: + x_int32 = x.to(tl.int32, bitcast=True) + zero = tl.cast(0.0, tl.int32) + if mask is not None: + x_int32 = tl.where(mask, x_int32, zero) + checksum_local = tl.xor_sum(tl.ravel(x_int32, can_reorder=True), 0) + tl.atomic_add(checksum_scale_ptr, checksum_local) + if mask is not None: + if actual_scale_ptr is not None: + x = tl.where(mask, x, 0.0) + update_scale(x, actual_scale_ptr, Out) + x = x * invscale + # if expected_scale_ptr is not None, we applied flexpoint scale. We only want to clip in this case. 
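+    # When saturate_infs is set, the scaled result is clamped to the destination format's
+    # largest finite value (e.g. +/-448 for float8e4nv), so overflows saturate instead of
+    # becoming +/-inf on the cast.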
+ if expected_scale_ptr_or_val is not None: + if saturate_infs: + CLIP_VALUE = max_finite(Out.dtype.element_ty) + x = clip(x, CLIP_VALUE) + return x diff --git a/kernel-microbench/tk/triton_kernels/numerics_details/mxfp.py b/kernel-microbench/tk/triton_kernels/numerics_details/mxfp.py new file mode 100644 index 0000000..0feb113 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/numerics_details/mxfp.py @@ -0,0 +1,808 @@ +from enum import Enum +import triton +import triton.language as tl +import torch +import torch.nn.functional as F + +# ----------------------------------------------------------------------------- +# Dequantization / Quantization Utilities +# ----------------------------------------------------------------------------- + + +def get_max_quant_val(dtype: torch.dtype): + d = {torch.uint8: 6.0, torch.float8_e5m2: 57344.0, torch.float8_e4m3fn: 448.0} + assert dtype in d + return d[dtype] + + +@tl.constexpr_function +def get_scaled_dot_format_string(dtype: tl.dtype): + mapping = { + tl.float16: "fp16", + tl.bfloat16: "bf16", + tl.uint8: "e2m1", + tl.float8e4nv: "e4m3", + tl.float8e5: "e5m2", + } + return mapping[dtype] + + +@triton.jit +def _get_max_quant_val(dtype: tl.constexpr): + if dtype == tl.uint8: + return 6.0 + elif dtype == tl.float8e5: + return 57344.0 + elif dtype == tl.float8e4nv: + return 448.0 + else: + tl.static_assert(False, f"Invalid {dtype=}") + + +# fmt: off +@triton.jit +def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr, + DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0): + is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5 + BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0] + BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1] + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // 32 + + # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16 + f32_tensor = src_tensor.to(tl.float32) + abs_tensor = tl.abs(f32_tensor) + abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0) # Don't consider padding tensors in scale computation + abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32]) + max_val = tl.max(abs_tensor, axis=2, keep_dims=True) + dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype) + if DEQUANT_SCALE_ROUNDING_MODE == 0: + # DequantScaleRoundingMode.ROUND_UP + # compute 2 ** ceil(log2(dequant_scale)) + # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros + # A corner case: exponent is 0xFF that will overflow but that's already + # NaN so assume we don't care. + dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000 + else: + # DequantScaleRoundingMode.ROUND_DOWN + # compute 2 ** floor(log2(dequant_scale)) + assert DEQUANT_SCALE_ROUNDING_MODE == 1 + dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000 + dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True) + quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded) + + f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32]) + quant_tensor = f32_tensor * quant_scale + + # Reshape the tensors after scaling + quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) + # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format. 
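+    # Zeroing the padded lanes also keeps the packed output clean: two e2m1 values share
+    # one byte below, so a padded lane that sits next to a valid lane would otherwise
+    # leave nonzero bits in the stored byte.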
+ quant_tensor = tl.where(valid_src_mask, quant_tensor, 0) + dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE]) + + # First, we simply extract the exponent part of the scales and store the result + dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8) + # Now we must convert the tensors to the mx format. + if is_fp8: + out_tensor = quant_tensor.to(mx_tensor_dtype) + else: + quant_tensor = quant_tensor.to(tl.uint32, bitcast=True) + signs = quant_tensor & 0x80000000 + exponents = (quant_tensor >> 23) & 0xFF + mantissas = (quant_tensor & 0x7FFFFF) + + # 0.25 <= x < 0.75 maps to 0.5, a denormal number + E8_BIAS = 127 + E2_BIAS = 1 + # Move implicit bit 1 at the beginning to mantissa for denormals + adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False) + mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas) + + # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0. + exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS) + + # Combine sign, exponent, and mantissa, while saturating + # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right + e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7) + e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8) + + e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2]) + evens, odds = tl.split(e2m1_value) + out_tensor = evens | (odds << 4) + + return out_tensor, dequant_scale_exponent + +@triton.jit +def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr, + mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant, + src_ptr, stride_src_outer, stride_src_quant, + outer_dim, quant_dim, + BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr, + DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr): + + tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.") + tl.static_assert(BLOCK_SIZE_QUANT_DIM % 32 == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32") + + # uint8 signifies two fp4 e2m1 values packed into a single byte + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5), + f"Invalid {mx_tensor_dtype=}. 
Must be uint8 or float8.") + + src_dtype: tl.constexpr = src_ptr.dtype.element_ty + tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8") + tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16), f"{src_dtype=} must be bfloat16 or float16") + is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5 + + outer_block = tl.program_id(0).to(tl.int64) + quant_block = tl.program_id(1).to(tl.int64) + + K_DIVISOR: tl.constexpr = 1 if is_fp8 else 2 + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // 32 + BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR + + start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM + start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE + start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR + start_out = outer_block * BLOCK_SIZE_OUT_DIM + + src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer + mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer + mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer + + offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64) + offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64) + offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64) + offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64) + + mask_src_quant = start_src_quant + offs_src_quant < quant_dim + mask_n = start_out + offs_outer < outer_dim + full_mask_src = mask_src_quant and mask_n + + mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR) + full_mask_mxt = mask_mxt_quant and mask_n + + scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, 32) + full_scale_mask = scale_mask_k and mask_n + + src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer + mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer + mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer + src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src) + + out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype, + DEQUANT_SCALE_ROUNDING_MODE) + + tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask) + tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt) + + +@triton.jit +def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, + mx_scale_ptr, stride_scale_outer, stride_scale_quant, + mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr, + outer_dim, quant_dim, + BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr): + + tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx") + tl.static_assert(BLOCK_SIZE_QUANT_DIM % 32 == 0, "BLOCK_SIZE_K must be a multiple of 32") + # uint8 signifies two fp4 e2m1 values packed into a single byte + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + dst_dtype: tl.constexpr = out_ptr.dtype.element_ty + tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16) + tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5), + "mx_tensor_ptr must be uint8") + tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, 
"mx_scale_ptr must be uint8") + + # Determine if we are dealing with fp8 types. + is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5 + K_DIVISOR: tl.constexpr = 1 if is_fp8 else 2 + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // 32 + BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR + + # Compute starting indices for the quantized (packed) dimension and the outer dimension. + outer_block = tl.program_id(0).to(tl.int64) + quant_block = tl.program_id(1).to(tl.int64) + + start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR + start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM + start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE + start_out = outer_block * BLOCK_SIZE_OUT_DIM + + mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer + mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer + out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant + + # Compute offsets and masks. + offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64) + offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64) + offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64) + offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64) + + mask_outer = start_out + offs_outer < outer_dim + mask_out_quant = start_out_quant + offs_out_quant < quant_dim + full_mask_out = mask_out_quant and mask_outer + + mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR) + full_mask_src = mask_src_quant and mask_outer + + mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, 32) + full_scale_mask = mask_scale and mask_outer + + tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer + scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer + out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer + + # Load the packed tensor and scale. + tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src) + scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask) + + # Upcast the scale to the destination type. + if dst_dtype == tl.bfloat16: + dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True) + else: + tl.static_assert(dst_dtype == tl.float16) + dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + dst_scale = dst_scale.to(tl.float16) + + # Now upcast the tensor. + if is_fp8: + dst_tensor = tensor.to(dst_dtype) + if tensor.dtype == tl.float8e5: + from_e_bits: tl.constexpr = 5 + from_m_bits: tl.constexpr = 2 + to_e_bits: tl.constexpr = 8 if dst_dtype == tl.bfloat16 else 5 + to_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10 + + # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them! 
+ non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits + non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits + dst_tensor = tl.where( + (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src, + (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(dst_dtype, bitcast=True), + dst_tensor, + ) + else: + dst_bias: tl.constexpr = 127 if dst_dtype == tl.bfloat16 else 15 + dst_0p5: tl.constexpr = 16128 if dst_dtype == tl.bfloat16 else 0x3800 + dst_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10 + # e2m1 + em0 = tensor & 0x07 + em1 = tensor & 0x70 + x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12) + x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8) + # Three cases: + # 1) x is normal and non-zero: Correct bias + x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0) + x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1) + # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1) + # 3) x is zero, do nothing + dst_tensor = tl.interleave(x0, x1).to(dst_dtype, bitcast=True) + + # Reshape for proper broadcasting: the scale was stored with a 32‐sized “inner” grouping. + dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32]) + dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1]) + scale = scale.reshape(dst_scale.shape) + + out_tensor = dst_tensor * dst_scale + # Correct any NaNs encoded via the scale. + out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor) + out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) + tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out) + + +class DequantScaleRoundingMode(Enum): + ROUND_UP = 0 + ROUND_DOWN = 1 + + +SWIZZLE_ALIGN_INNER = 8 +SWIZZLE_SIZE_INNER = 4 +SWIZZLE_SIZE_OUTER = 128 + +@triton.jit +def _unswizzle_mx_block(x, + SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER, + SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER, + ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER): + shape_0: tl.constexpr = x.shape[0] + shape_1: tl.constexpr = x.shape[1] + tl.static_assert(shape_1 % SIZE_OUTER == 0) + tl.static_assert(shape_1 // SIZE_OUTER <= ALIGN_INNER) + x = x.reshape(shape_0, (shape_1 // SIZE_OUTER) // SIZE_INNER, 32, SIZE_OUTER // 32, SIZE_INNER) + x = x.trans(0, 3, 2, 1, 4).reshape(shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER) + return x + + +def axis_permute_order(ndim: int, axis: int, swizzle_axis: int | None = None) -> list[int]: + permute_order = list(range(ndim)) + permute_order[axis], permute_order[-1] = permute_order[-1], permute_order[axis] + + scale_permute_order = permute_order.copy() + if swizzle_axis is not None: + axis = axis if axis >= 0 else axis + ndim + swizzle_axis = swizzle_axis if swizzle_axis >= 0 else swizzle_axis + ndim + if swizzle_axis == ndim - 1: + swizzle_axis = axis + scale_permute_order[swizzle_axis], scale_permute_order[-2] = scale_permute_order[-2], scale_permute_order[swizzle_axis] + + convert_order = [i for i, (a, b) in enumerate(zip(permute_order, scale_permute_order)) if a != b] + assert len(convert_order) == 0 or len(convert_order) == 2, "Exactly 0 or 1 swap should be required to transform permute_order to scale_permute_order." 
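    # Worked example (editorial): for ndim=3, axis=1, swizzle_axis=0 this returns
    #   permute_order       = [0, 2, 1]   (quantization axis moved last)
    #   scale_permute_order = [2, 0, 1]   (swizzle axis moved second-to-last for the scales)
    #   convert_order       = [0, 1]      (transpose dims 0 and 1 to convert between them)
    # so the data tensor is laid out with the quantization axis last, while the scales
    # additionally place the swizzle axis second-to-last.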
+ return permute_order, scale_permute_order, convert_order + + +def transpose_shape(shape: tuple[int, ...], i: int, j: int) -> tuple[int, ...]: + shape = list(shape) + shape[i], shape[j] = shape[j], shape[i] + return tuple(shape) + + +def permute_shape(shape: tuple[int, ...], permute_order: list[int]) -> tuple[int, ...]: + return tuple(shape[i] for i in permute_order) + + +def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int, swizzle_axis: int | None = None, + out_quant_tensor: torch.Tensor | None = None, out_scale: torch.Tensor | None = None, + DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode=DequantScaleRoundingMode.ROUND_UP, + BLOCK_OUT_DIM: int = 128, BLOCK_QUANT_DIM: int = 32): + """ + Convert the src weights to mx format. The src weight is quantized along the axis dimension. + + If weight_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte. + Note that this means the k_dim of the tensor will be half of the logical k_dim. + + If weight_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 with the float8s are stored + in their respective formats. + + When swizzle_axis is provided, the downcast will quantize along the quantization axis and swizzle these values + with the swizzle_axis from layout (A, B, ..., N, K) to (A, B, ..., N // 128, K // 4, 32, 4, 4), where N is the + swizzle dimension and K is the quantization dimension. The swizzled scales are then reinterpreted back as + (A, B, ..., N, K), contiguous along K, and permuted back to the original input layout. + In order to swizzle in the target layout, the scales are padded to be divisible by 128 and 4 along the + swizzle and quantization dimensions, respectively. + """ + ndim = src_tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + axis = axis if axis >= 0 else axis + ndim + if swizzle_axis is not None: + assert -ndim <= swizzle_axis < ndim, f"Invalid swizzle axis {swizzle_axis=}" + swizzle_axis = swizzle_axis if swizzle_axis >= 0 else swizzle_axis + ndim + + L = src_tensor.shape[axis] + if out_quant_type == torch.uint8: + # We make this assertion since we can't track if the "real" shape was odd, and we padded it to be even. + # We want to maintain the property dequant(quant(x)).shape == x.shape + assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. 
Got {L}" + + is_fp8 = out_quant_type == torch.float8_e4m3fn or out_quant_type == torch.float8_e5m2 + divisor = 1 if is_fp8 else 2 + device = src_tensor.device + + packed_quant_dim = triton.cdiv(L, divisor) + out_scale_dim = triton.cdiv(L, 32) + + permute_order, scale_permute_order, convert_order = axis_permute_order(ndim, axis, swizzle_axis) + + prmted_quant_tensor_shape = permute_shape(src_tensor.shape, permute_order)[:-1] + (packed_quant_dim,) + prmted_scale_shape = permute_shape(src_tensor.shape, scale_permute_order)[:-1] + (out_scale_dim,) + prmted_src_tensor = src_tensor.permute(permute_order) + + if out_quant_tensor is None: + out_quant_tensor = torch.empty(prmted_quant_tensor_shape, dtype=out_quant_type, device=device) + else: + expected_shape = src_tensor.shape[:axis] + (packed_quant_dim,) + src_tensor.shape[axis + 1:] + assert out_quant_tensor.shape == expected_shape, f"{out_quant_tensor.shape=} != {expected_shape=}" + assert out_quant_tensor.dtype == out_quant_type, f"{out_quant_tensor.dtype=} != {out_quant_type=}" + assert out_quant_tensor.stride(axis) == 1, f"{out_quant_tensor.stride(axis)=} != 1" + # We expect the axis dimension to be last, so permute the tensor + out_quant_tensor = out_quant_tensor.permute(permute_order) + + if out_scale is None: + allocation_shape = prmted_scale_shape + if swizzle_axis is not None: + allocation_shape = list(prmted_scale_shape) + allocation_shape[-1] = triton.cdiv(allocation_shape[-1], SWIZZLE_ALIGN_INNER) * SWIZZLE_ALIGN_INNER + allocation_shape[-2] = triton.cdiv(allocation_shape[-2], SWIZZLE_SIZE_OUTER) * SWIZZLE_SIZE_OUTER + out_scale = torch.empty(allocation_shape, dtype=torch.uint8, device=device) + else: + if swizzle_axis is not None: + expected_scale_shape = list(prmted_scale_shape) + # Pad then unpermute the expected shape + expected_scale_shape[-1] = triton.cdiv(expected_scale_shape[-1], SWIZZLE_ALIGN_INNER) * SWIZZLE_ALIGN_INNER + expected_scale_shape[-2] = triton.cdiv(expected_scale_shape[-2], SWIZZLE_SIZE_OUTER) * SWIZZLE_SIZE_OUTER + expected_scale_shape = permute_shape(expected_scale_shape, scale_permute_order) + else: + expected_scale_shape = permute_shape(prmted_scale_shape, scale_permute_order) + + assert out_scale.shape == expected_scale_shape, f"{out_scale.shape=} {expected_scale_shape=}" + assert out_scale.dtype == torch.uint8, f"{out_scale.dtype=} != torch.uint8" + out_scale = out_scale.permute(scale_permute_order) + + if convert_order or prmted_scale_shape != out_scale.shape: + # Output shape is padded. Make a new unpadded tensor. + assert swizzle_axis is not None # padding only occurs in the swizzled case. + # scales should be produced in `permute_order`. + unpadded_out_scale = torch.empty(transpose_shape(prmted_scale_shape, *convert_order) if convert_order else prmted_scale_shape, dtype=torch.uint8, device=device) + else: + unpadded_out_scale = out_scale + + # Flatten input tensor for kernel. 
This will typically make a copy + reshaped_src_tensor = prmted_src_tensor.reshape(-1, L) + blocks_quant_dim = triton.cdiv(reshaped_src_tensor.shape[-1], BLOCK_QUANT_DIM) + blocks_out_dim = triton.cdiv(reshaped_src_tensor.shape[0], BLOCK_OUT_DIM) + + # Flatten the output tensors for the kernel, this should be a view always + kernel_quant_tensor = out_quant_tensor.reshape(-1, packed_quant_dim) + kernel_scale = unpadded_out_scale.reshape(-1, out_scale_dim) + assert kernel_quant_tensor.data_ptr() == out_quant_tensor.data_ptr() + assert kernel_scale.data_ptr() == unpadded_out_scale.data_ptr() + + _downcast_to_mxfp[(blocks_out_dim, blocks_quant_dim)]( + kernel_quant_tensor, *kernel_quant_tensor.stride(), + kernel_scale, *kernel_scale.stride(), + reshaped_src_tensor, *reshaped_src_tensor.stride(), + *reshaped_src_tensor.shape, + BLOCK_OUT_DIM, BLOCK_QUANT_DIM, DEQUANT_SCALE_ROUNDING_MODE.value, + num_warps=8 + ) + + if convert_order or prmted_scale_shape != out_scale.shape: + if convert_order: + # convert scales from `permute_order` to `scale_permute_order` + unpadded_out_scale = unpadded_out_scale.transpose(*convert_order) + # Copy from the unpadded shape into the padded one. + out_scale[tuple(slice(0, size) for size in unpadded_out_scale.shape)] = unpadded_out_scale + + # Zero out any padding. `tcgen05.mma` yields MAX_FINITE for the entire block if any + # scales are not finite (0xFF). + slices = [slice(None) for _ in unpadded_out_scale.shape] + for i, size in enumerate(unpadded_out_scale.shape): + slices[i] = slice(size, None) + out_scale[slices] = 0 + slices[i] = slice(None) + + out_quant_tensor = out_quant_tensor.permute(permute_order) + + if swizzle_axis is not None: + out_scale = swizzle_mx(out_scale, allow_pad=False).contiguous().permute(scale_permute_order) + else: + out_scale = out_scale.permute(permute_order).contiguous() + return out_quant_tensor, out_scale, permute_shape(prmted_scale_shape, scale_permute_order) + + +def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, axis: int, swizzle_axis: int | None = None, + BLOCK_OUT_DIM: int = 128, BLOCK_QUANT_DIM: int = 32): + """ + Upcasts an mxfp (packed) weight tensor back to float16 or bfloat16. + + The function assumes that the tensors were quantized along the given axis. + It permutes the tensor so that the quantized axis is last, reshapes to 2D, + launches the Triton upcast kernel, and then unpermutes back to the original order. + """ + ndim = tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + axis = axis if axis >= 0 else axis + ndim + if swizzle_axis is not None: + assert -ndim <= swizzle_axis < ndim, f"Invalid swizzle axis {swizzle_axis=}" + swizzle_axis = swizzle_axis if swizzle_axis >= 0 else swizzle_axis + ndim + + multiplier = 1 if "float8" in str(tensor.dtype) else 2 + logical_quant_dim_shape = tensor.shape[axis] * multiplier + assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. " + f"Got {tensor.ndim=} and {scale.ndim=}") + quant_dim_align = SWIZZLE_ALIGN_INNER if swizzle_axis is not None else 1 + assert triton.cdiv(triton.cdiv(logical_quant_dim_shape, 32), quant_dim_align) * quant_dim_align == scale.shape[axis], \ + f"Tensor and scale mismatch along quantization axis. 
Got {tensor.shape[axis]=} and {scale.shape[axis]=}" + assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \ + f"Invalid tensor dtype {tensor.dtype=}" + assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}" + assert dtype in {torch.float16, torch.bfloat16}, f"Invalid output dtype {dtype=}" + + # Bring the quantized axis to the end. + # For the scales, bring the swizzle axis second to last. + permute_order, scale_permute_order, convert_order = axis_permute_order(ndim, axis, swizzle_axis) + prmt_tensor = tensor.permute(permute_order).contiguous() + prmt_scale = scale.permute(scale_permute_order) + + # Unswizzle the scale tensor and slice off padding. + if swizzle_axis is not None: + prmt_scale = unswizzle_mx(prmt_scale) + + unpadded_scale_shape = (*prmt_tensor.shape[:-1], triton.cdiv(logical_quant_dim_shape, 32)) + # The kernel expects scales in `permute_order`, not `scale_permute_order`. Transpose if needed. + if convert_order: + prmt_scale = prmt_scale.transpose(*convert_order) + + slices = tuple(slice(0, size) for size in unpadded_scale_shape) + prmt_scale = prmt_scale[slices] + + prmt_scale = prmt_scale.contiguous() + + quant_dim = prmt_tensor.shape[-1] + reshaped_tensor = prmt_tensor.reshape(-1, quant_dim) + reshaped_scale = prmt_scale.reshape(-1, prmt_scale.shape[-1]) + + outer_dim = reshaped_tensor.shape[0] + blocks_out_dim = triton.cdiv(outer_dim, BLOCK_OUT_DIM) + blocks_quant_dim = triton.cdiv(logical_quant_dim_shape, BLOCK_QUANT_DIM) + + out = torch.empty((outer_dim, logical_quant_dim_shape), dtype=dtype, device=tensor.device) + _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)]( + out, out.stride(0), out.stride(1), + reshaped_scale, reshaped_scale.stride(0), reshaped_scale.stride(1), + reshaped_tensor, reshaped_tensor.stride(0), reshaped_tensor.stride(1), + outer_dim, logical_quant_dim_shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM, num_warps=8 + ) + # Reshape back to the permuted shape. + out = out.view(*prmt_tensor.shape[:-1], logical_quant_dim_shape) + out = out.permute(permute_order) + return out + + +def right_shift_unsigned(x, shift): + # CUDA torch does not support bit ops on uint32, so we need to mask to get unsigned right shift + return (x >> shift) & ((1 << (32 - shift)) - 1) + + +def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int, swizzle_axis: int | None = None, + out_quant_tensor: torch.Tensor | None = None, out_scale: torch.Tensor | None = None, + DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP): + """ + Converts the src tensor to the output format specified by out_quant_type. + axis: The axis along which the tensors are contiguous and quantization is applied. + DEQUANT_SCALE_ROUNDING_MODE: 0 for ROUND_UP, 1 for ROUND_DOWN. + + Returns: + out_quant_tensor: Quantized tensor in mx format. + • For mxfp8, the output has the same shape as src_tensor. + • For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8. + scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis. + Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32), + where L is the original length along that axis. 
+ """ + + ndim = src_tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + assert src_tensor.dtype in {torch.float32, torch.bfloat16, torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}" + + axis = axis if axis >= 0 else axis + ndim + if swizzle_axis is not None: + assert -ndim <= swizzle_axis < ndim, f"Invalid swizzle axis {swizzle_axis=}" + swizzle_axis = swizzle_axis if swizzle_axis >= 0 else swizzle_axis + ndim + is_fp4 = out_quant_type == torch.uint8 + is_fp8 = "float8" in str(out_quant_type) + assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}" + + device = src_tensor.device + + # For mxfp4 conversion, we assume the contiguous axis length is even. + if is_fp4: + axis_shape = src_tensor.size(axis) + assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even." + + # Permute the tensor so that the contiguous axis becomes the last dimension. + # For the scales, make the swizzle axis is second to last. + permute_order, scale_permute_order, convert_order = axis_permute_order(ndim, axis, swizzle_axis) + src = src_tensor.permute(permute_order).to(torch.float32) # now shape: (..., axis_shape) + axis_shape = src.shape[-1] + + # Pad the axis to be divisible by 32, in case it is not. + next_multiple = (axis_shape + 31) // 32 * 32 + pad_amount = next_multiple - axis_shape + padded_src = F.pad(src, (0, pad_amount)) + valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount)) + padded_axis_shape = padded_src.size(-1) # now divisible by 32 + + # --- Compute per-group maximums for scale --- + # Set padded entries to -1 so they don’t affect the max. + abs_f = torch.abs(padded_src) + abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype)) + # Reshape the last dimension into groups of 32. + new_shape = padded_src.shape[:-1] + (padded_axis_shape // 32, 32) + abs_groups = abs_f.view(*new_shape) + # Compute maximum along the group dimension (of size 32). + max_val, _ = abs_groups.max(dim=-1, keepdim=True) + + # Choose a max quantization value depending on type. + max_quant_val = get_max_quant_val(out_quant_type) + dequant_scale = max_val / max_quant_val # shape: (..., padded_axis_shape//32, 1) + + # Convert to int to round the FP32 scale, prior to quantization! + ds_int = dequant_scale.view(torch.int32) + if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP: + ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000 + else: + ds_int_rounded = ds_int & 0x7F800000 + # Reinterpret back as float32. + dequant_scale_rounded = ds_int_rounded.view(torch.float32) + + # Compute the quantization scale. 
+ quant_scale = torch.where(dequant_scale_rounded == 0, + torch.tensor(0.0, device=device), + 1.0 / dequant_scale_rounded) + + # Quantize the tensor + orig_padded_shape = padded_src.shape + padded_src_groups = padded_src.view(*new_shape) + quant_tensor = padded_src_groups * quant_scale + # Reshape back to the original shape and trim padding + quant_tensor = quant_tensor.view(orig_padded_shape) + quant_tensor = quant_tensor[..., :axis_shape] + + # Finally, convert the quantized tensor to the target format + if is_fp8: + # Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior + quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val) + out_weight = quant_tensor.to(out_quant_type) + else: + assert is_fp4, f"Invalid output quantization type {out_quant_type}" + # For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8. + # First, reinterpret the quantized tensor bits. + q_int = quant_tensor.contiguous().view(torch.int32) + # Extract sign, exponent, and mantissa. + signs = q_int & 0x80000000 + exponents = right_shift_unsigned(q_int, 23) & 0xFF + mantissas = q_int & 0x7FFFFF + + E8_BIAS = 127 + E2_BIAS = 1 + # Adjust mantissas for subnormals. + mantissas = torch.where(exponents < E8_BIAS, + (0x400000 | right_shift_unsigned(mantissas, 1)) >> (E8_BIAS - exponents - 1), + mantissas) + exponents = torch.maximum(exponents, + torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS) + e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1) + e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device)) + e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8) # shape: (..., even_axis_shape) + + # Pack pairs of 4-bit values along the last dimension. + e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2) + evens = e2m1_value[..., 0] + odds = e2m1_value[..., 1] + out_weight = evens | (odds << 4) # shape: (..., axis_shape//2) + + # --- Process and output the scale --- + dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8) # shape: (..., axis_shape//32, 1) + dq_scale = dq_scale.squeeze(-1) + + if convert_order: + # dq_scale was produced in `permute_order`, but we want it to be in `scale_permute_order`. + dq_scale = dq_scale.transpose(*convert_order) + + if swizzle_axis is not None: + dq_scale = swizzle_mx(dq_scale) + + # Now, invert the permutation so that the contiguous axis returns to its original position. 
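# Illustrative sketch (plain PyTorch, values hypothetical) of the nibble packing used for
# mxfp4 above: even-indexed e2m1 codes go in the low nibble and odd-indexed codes in the
# high nibble, so the packed axis is half the logical length.
import torch
codes = torch.tensor([0x1, 0x7, 0x9, 0x2], dtype=torch.uint8)   # four 4-bit e2m1 codes
packed = codes[0::2] | (codes[1::2] << 4)                       # -> tensor([0x71, 0x29])
unpacked = torch.stack([packed & 0xF, (packed >> 4) & 0xF], dim=-1).flatten()
assert torch.equal(unpacked, codes)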
+ out_weight = out_weight.permute(permute_order) + dq_scale = dq_scale.permute(scale_permute_order) + + if out_quant_tensor is not None: + assert out_quant_tensor.shape == out_weight.shape, f"Invalid shape {out_quant_tensor.shape} != {out_weight.shape}" + assert out_quant_tensor.dtype == out_weight.dtype, f"Invalid dtype {out_quant_tensor.dtype} != {out_weight.dtype}" + out_quant_tensor.copy_(out_weight) + else: + out_quant_tensor = out_weight + + if out_scale is not None: + assert out_scale.shape == dq_scale.shape, f"Invalid shape {out_scale.shape} != {dq_scale.shape}" + assert out_scale.dtype == dq_scale.dtype, f"Invalid dtype {out_scale.dtype} != {dq_scale.dtype}" + out_scale.copy_(dq_scale) + else: + out_scale = dq_scale + + return out_quant_tensor, out_scale.contiguous() + + +def cvt_e2m1_to_fp32(input_tensor): + assert input_tensor.dtype == torch.uint8 + + input_tensor = input_tensor.to(torch.int32) + evens = input_tensor & 0xF + odds = (input_tensor >> 4) & 0xF + + vals = [0.0, 0.5, 1, 1.5, 2, 3, 4, 6] + outputs = torch.tensor(vals, dtype=torch.float32, device=input_tensor.device) + outputs = torch.cat([outputs, -outputs]) + + even_floats = outputs[evens] + odd_floats = outputs[odds] + output_tensor = torch.stack([even_floats, odd_floats], dim=-1) + output_tensor = output_tensor.view(*input_tensor.shape[:-1], -1) + return output_tensor + + +def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int, swizzle_axis: int | None = None): + """ + Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype. + axis: The axis along which dequantization is applied. + + Returns: + out_weight: Tensor in the target format. + """ + + ndim = tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + is_fp8 = tensor.dtype == torch.float8_e4m3fn or tensor.dtype == torch.float8_e5m2 + assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}" + + # Permute the tensor and scale so that the quantization axis becomes the last dimension + # For the scales, also permute so the swizzle axis is second to last. 
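# Usage sketch for the pure-PyTorch reference above (the shapes shown are assumptions
# traced from the code, not taken from the diff): quantizing a (4, 64) bfloat16 tensor
# to mxfp8 along the last axis keeps the tensor shape and emits one e8m0 scale byte per
# group of 32 elements.
import torch
x = torch.randn(4, 64, dtype=torch.bfloat16)
qt, sc = downcast_to_mxfp_torch(x, torch.float8_e4m3fn, axis=-1)
# qt: shape (4, 64), dtype torch.float8_e4m3fn
# sc: shape (4, 2),  dtype torch.uint8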
+ axis = axis if axis >= 0 else axis + ndim + if swizzle_axis is not None: + assert -ndim <= swizzle_axis < ndim, f"Invalid swizzle axis {swizzle_axis=}" + swizzle_axis = swizzle_axis if swizzle_axis >= 0 else swizzle_axis + ndim + permute_order, scale_permute_order, convert_order = axis_permute_order(ndim, axis, swizzle_axis) + + tensor = tensor.permute(permute_order) + scale = scale.permute(scale_permute_order) + + if swizzle_axis is not None: + scale = unswizzle_mx(scale) + + dq_scale = (scale.to(torch.int32) << 23).view(torch.float32) # Shift to the exponent and bitcast to fp32 + + if is_fp8: + fp32_tensor = tensor.to(torch.float32) + else: + assert tensor.dtype == torch.uint8 + fp32_tensor = cvt_e2m1_to_fp32(tensor) + + fp_tensor_shape = fp32_tensor.shape + if convert_order: + fp_tensor_shape = transpose_shape(fp_tensor_shape, *convert_order) + + # Trim padding + dq_scale = dq_scale[..., :fp_tensor_shape[-2], :(fp_tensor_shape[-1] + 31) // 32] + if convert_order: + dq_scale = dq_scale.transpose(*convert_order) + + axis_shape = fp32_tensor.size(-1) + padded_axis_shape = dq_scale.size(-1) * 32 + pad_size = padded_axis_shape - axis_shape + padded_tensor = F.pad(fp32_tensor, (0, pad_size)) + + new_axis_shape = padded_tensor.shape[-1] + new_shape = padded_tensor.shape[:-1] + (new_axis_shape // 32, 32) + padded_tensor = padded_tensor.view(*new_shape) + dq_scale_padded = dq_scale.unsqueeze(-1) # shape: [..., ceil(axis_shape/32), 1] + out_padded = padded_tensor * dq_scale_padded + + # Flatten back and remove the padded tail + out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape) + out_tensor = out_padded[..., :axis_shape] + + out_tensor = out_tensor.permute(permute_order).to(target_dtype) + return out_tensor + + +def swizzle_mx(tensor: torch.Tensor, allow_pad=True): + """ + Swizzle the input tensor of shape (A, B, ... N, K) to (A, B, ... N // 128, K // 4, 32, 4, 4). + Padding is applied if N and K are not multiples of 128 and 4 respectively. + Returns the swizzled tensor repacked as (A, B, ... N, K), with padding. + """ + *leading_shape, N, K, = tensor.shape + pad_k = (SWIZZLE_ALIGN_INNER - (K % SWIZZLE_ALIGN_INNER)) % SWIZZLE_ALIGN_INNER + pad_n = (SWIZZLE_SIZE_OUTER - (N % SWIZZLE_SIZE_OUTER)) % SWIZZLE_SIZE_OUTER + if pad_k or pad_n > 0: + assert allow_pad, "Padding is required for swizzling, but it was explicitly disabled." + tensor = torch.nn.functional.pad(tensor, (0, pad_k, 0, pad_n)) + padded_shape = tensor.shape + tensor = tensor.reshape(*leading_shape, padded_shape[-2] // SWIZZLE_SIZE_OUTER, SWIZZLE_SIZE_OUTER // 32, 32, padded_shape[-1] // SWIZZLE_SIZE_INNER, SWIZZLE_SIZE_INNER) + permute_order = list(range(len(tensor.shape))) + permute_order[-2], permute_order[-4] = permute_order[-4], permute_order[-2] + return tensor.permute(permute_order).reshape(*padded_shape) + + +def unswizzle_mx(tensor: torch.Tensor): + """ + Unswizzle the input tensor of shape (A, B, ... N // 128, K // 4, 32, 4, 4) packed as (A, B, ... N, K). 
+ """ + assert tensor.shape[-1] % SWIZZLE_SIZE_INNER == 0, f"{tensor.shape[-1]=} must be a multiple of {SWIZZLE_SIZE_INNER}" + assert tensor.shape[-2] % SWIZZLE_SIZE_OUTER == 0, f"{tensor.shape[-2]=} must be a multiple of {SWIZZLE_SIZE_OUTER}" + *leading_shape, N, K, = tensor.shape + tensor = tensor.reshape(*leading_shape, N // SWIZZLE_SIZE_OUTER, K // SWIZZLE_SIZE_INNER, 32, SWIZZLE_SIZE_OUTER // 32, SWIZZLE_SIZE_INNER) + permute_order = list(range(len(tensor.shape))) + permute_order[-2], permute_order[-4] = permute_order[-4], permute_order[-2] + return tensor.permute(permute_order).reshape(*leading_shape, N, K) diff --git a/kernel-microbench/tk/triton_kernels/reduction.py b/kernel-microbench/tk/triton_kernels/reduction.py new file mode 100644 index 0000000..1d0a2e0 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/reduction.py @@ -0,0 +1,16 @@ +import torch +import triton +from .reduction_details.reduce_bitmatrix import sum_bitmatrix_rows +from . import Bitmatrix + + +def sum(x, partials_block_size=None, dim=0): + cdiv = triton.cdiv + assert isinstance(x, Bitmatrix) + assert dim == 0 + assert partials_block_size is not None + n_rows, n_cols = x.shape + dev = x.data.device + out_ret = torch.empty(n_cols, dtype=torch.int32, device=dev) + out_partials = torch.empty((cdiv(n_rows, partials_block_size), n_cols), dtype=torch.int32, device=dev) + return sum_bitmatrix_rows(x, out_ret, out_partials, partials_block_size) diff --git a/kernel-microbench/tk/triton_kernels/reduction_details/reduce_bitmatrix.py b/kernel-microbench/tk/triton_kernels/reduction_details/reduce_bitmatrix.py new file mode 100644 index 0000000..4cdfafe --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/reduction_details/reduce_bitmatrix.py @@ -0,0 +1,90 @@ +import triton +import triton.language as tl + + +@triton.jit +def vpopc(x): + """ + Vertical popcount + Input x : uint32[..., N] + Output y : uint32[..., 32] + semantics : y[..., i] = sum_j((x[..., j] >> i) & 1) + credits: @apgoucher + """ + + tl.static_assert(x.dtype == tl.uint32, "x should consist of 32-bit unsigned integers") + + BLOCK_N: tl.constexpr = x.shape[-1] # summation axis + BATCHES: tl.constexpr = x.numel // BLOCK_N # number of batches + if BLOCK_N >= 8: + sa1: tl.constexpr = 8 + else: + sa1: tl.constexpr = BLOCK_N + # create 8-way sums in 4-bit fields: + y = tl.reshape(x, [BATCHES, BLOCK_N // sa1, sa1, 1]) + y = (y >> tl.arange(0, 4)[None, None, None, :]) & 0x11111111 + y = tl.sum(y, 2) # [BATCHES, BLOCK_N // sa1, 4] + if BLOCK_N >= 128: + sa2: tl.constexpr = 16 + else: + sa2: tl.constexpr = BLOCK_N // sa1 + # create 128-way sums in 8-bit fields: + y = tl.reshape(y, [BATCHES, BLOCK_N // (sa1 * sa2), sa2, 1, 4]) + y = (y >> (4 * tl.arange(0, 2))[None, None, None, :, None]) & 0x0f0f0f0f + y = tl.sum(y, 2) # [BATCHES, BLOCK_N // (sa1 * sa2), 2, 4] + sa3: tl.constexpr = BLOCK_N // (sa1 * sa2) + # create N-way sums in 32-bit fields: + y = tl.reshape(y, [BATCHES, 1, sa3, 8]) + y = (y >> (8 * tl.arange(0, 4))[None, :, None, None]) & 0x000000ff + y = tl.sum(y, 2) # [BATCHES, 4, 8] + y = tl.reshape(y, x.shape[:-1] + [32]) + return y + + +@triton.jit +def _sum_bitmatrix_memset(Ret, ret_size, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + tl.store(Ret + offs, 0, mask=offs < ret_size) + + +@triton.jit +def _sum_bitmatrix_rows(B, shape_bm, stride_bm, # input bitmatrix + Ret, Partials, stride_pm, shape_pn, # outputs + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + tl.static_assert(BLOCK_N % 32 == 0) + pid_m = 
tl.program_id(0) + pid_n = tl.program_id(1) + BLOCK_B: tl.constexpr = BLOCK_N // 32 + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_b = pid_n * BLOCK_B + tl.arange(0, BLOCK_B) + bits = tl.load(B + offs_m[None, :] * stride_bm + offs_b[:, None], mask=offs_m[None, :] < shape_bm) + ret = tl.reshape(vpopc(bits), [BLOCK_N]) + mask = offs_n < shape_pn + tl.atomic_add(Ret + offs_n, ret, mask=mask) + tl.store(Partials + pid_m * stride_pm + offs_n, ret, mask=mask) + + +def sum_bitmatrix_rows(x, out_ret, out_partials, partials_block_size=None): + assert partials_block_size is not None + cdiv = triton.cdiv + PARTIALS_BLOCK_M = partials_block_size + BLOCK_N = 32 + MEMSET_BLOCK = 512 + n_rows, n_cols = x.shape + assert out_ret.shape == (n_cols, ) + assert out_partials.shape == (cdiv(n_rows, PARTIALS_BLOCK_M), n_cols) + # output tensors + _sum_bitmatrix_memset[(cdiv(out_ret.shape[0], MEMSET_BLOCK), )]( + out_ret, out_ret.shape[0], # outputs + BLOCK=512 # tunable parameter + ) + _sum_bitmatrix_rows[(cdiv(n_rows, PARTIALS_BLOCK_M), cdiv(n_cols, BLOCK_N))]( + x.data, x.data.shape[0], x.data.stride(0), # input + out_ret, # output [final reduction] + out_partials, out_partials.stride(0), out_partials.shape[1], # output [partial reductions] + BLOCK_N=BLOCK_N, # tunable parameters + BLOCK_M=PARTIALS_BLOCK_M, # constants + ) + return out_ret, out_partials diff --git a/kernel-microbench/tk/triton_kernels/routing.py b/kernel-microbench/tk/triton_kernels/routing.py new file mode 100644 index 0000000..94ca0d0 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/routing.py @@ -0,0 +1,140 @@ +import torch +import triton +from dataclasses import dataclass, field +from .routing_details._routing_compute import _routing_memset_indx +from .routing_details._routing_compute import _routing_compute_indx_offs +from .routing_details._routing_compute import _routing_compute_indx +from .routing_details._routing_compute import _routing_clear_bitmatrix + + +@dataclass +class GatherIndx: + """ + Indices for an operation that performs: + Y = X[src_idx, :] + """ + # array such that `dst_idx[src_idx] = arange(0, N)` + src_indx: torch.Tensor + dst_indx: torch.Tensor + + +@dataclass +class ScatterIndx: + """ + Indices for an operation that performs: + Y[dst_idx, :] = X + """ + # array such that `dst_idx[src_idx] = arange(0, N)` + src_indx: torch.Tensor + dst_indx: torch.Tensor + + +@dataclass +class RoutingData: + gate_scal: torch.Tensor = field() + expt_hist: torch.Tensor = field() + n_expts_tot: int = field() + n_expts_act: int = field() + + # Used to make perf annotation cleaner: when we use expert sharding, we can + # use this to tell the "expected" number of local tokens per expert, because + # the actual number can vary per each input. 
+ expected_tokens_per_expt: int = field(default=None) + + def n_blocks(self, n_rows, block_m): + if n_rows <= self.n_expts_tot: + return n_rows + else: + return triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m) + self.n_expts_tot - 1 + + +# -------------------------- +# Triton routing +# -------------------------- + + +def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1): + from .topk import topk + from .reduction import sum + from .compaction import compaction + assert expt_indx is None + cdiv = triton.cdiv + HIST_BLOCK_M = 64 + INDX_OFFS_BLOCK_M = 512 + MEMSET_BLOCK = 1024 + assert logits.dtype.itemsize == 2 + n_tokens, n_expts_tot = logits.shape + n_gates = n_tokens * n_expts_act + device = logits.device + expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act) + # mutate bitmatrix + if simulated_ep > 1: + assert n_expts_tot % simulated_ep == 0 + _routing_clear_bitmatrix[(n_tokens, )]( + bitmatrix.data, + bitmatrix.data.stride(0), + bitmatrix.data.shape[1], + n_expts_tot // simulated_ep, + BLOCK_N=512, + ) + expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix) + n_expts_tot = n_expts_tot // simulated_ep + bitmatrix.shape[-1] = n_expts_tot + # perform compaction to update expt_scal / expt_indx + hist, partial_hist = sum(bitmatrix, partials_block_size=HIST_BLOCK_M, dim=0) + # scratchpad + expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device) + indx_offs = torch.empty((cdiv(n_tokens, HIST_BLOCK_M), n_expts_tot), dtype=torch.int32, device=device) + combined_indx = torch.empty(n_gates * 2, dtype=torch.int32, device=device) + # output + topk_indx = combined_indx[:n_gates] + gate_indx = combined_indx[n_gates:] + gate_scal = torch.empty(n_gates, dtype=logits.dtype, device=device) + _routing_memset_indx[(cdiv(n_gates * 2, MEMSET_BLOCK) + 1, )](combined_indx, n_gates * 2, -1, MEMSET_BLOCK, hist, + expt_offs, hist.shape[0], BLOCK_N=512) + _routing_compute_indx_offs[(n_expts_tot, )]( + expt_offs, partial_hist, # inputs + indx_offs, partial_hist.shape[0], partial_hist.stride(0), # outputs + BLOCK_M=INDX_OFFS_BLOCK_M, # tunable parameters + ) + _routing_compute_indx[(cdiv(n_tokens, HIST_BLOCK_M), )]( + topk_indx, gate_indx, gate_scal, # outputs + expt_scal, expt_indx, indx_offs, indx_offs.stride(0), n_gates, # input + BLOCK_M=HIST_BLOCK_M, # tunable parameters + N_EXPTS_ACT=n_expts_act, # constants + num_warps=1 if HIST_BLOCK_M * n_expts_act // 32 < 4 else 4) + # pack the matmul data structure + gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx) + scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx) + return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act), gather_indx, scatter_indx + + +def routing_torch(logits, n_expts_act, expt_indx=None): + + def topk(vals, k, expt_indx): + # topk of experts + if expt_indx is None: + tk_idx = torch.argsort(-vals, dim=1, stable=True)[:, :k] + else: + tk_idx = expt_indx + tk_val = torch.take_along_dim(vals, tk_idx, dim=1) + return tk_val, tk_idx + + _, n_expts_tot = logits.shape + expt_scal, expt_indx = topk(logits, n_expts_act, expt_indx) + expt_scal = torch.softmax(expt_scal, dim=-1) + # Sort each token's selections by expert + expt_indx, sort_indices = torch.sort(expt_indx, dim=1) + expt_scal = torch.gather(expt_scal, 1, sort_indices) + # flatten topk data + expt_scal = expt_scal.reshape(-1) + expt_indx = expt_indx.reshape(-1).to(torch.int32) + # sort by expert_id so experts are contiguous for the matmul + topk_indx = torch.argsort(expt_indx, stable=True) + gate_indx = 
torch.argsort(topk_indx) + gate_scal = expt_scal[topk_indx] + hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1) # histogram of tokens over experts + # pack the matmul data structure + gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int()) + scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int()) + return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act), gather_indx, scatter_indx diff --git a/kernel-microbench/tk/triton_kernels/routing_details/_routing_compute.py b/kernel-microbench/tk/triton_kernels/routing_details/_routing_compute.py new file mode 100644 index 0000000..3e45a23 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/routing_details/_routing_compute.py @@ -0,0 +1,108 @@ +import triton +import triton.language as tl + + +@triton.jit +def _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, # histogram + BLOCK_N: tl.constexpr): + loop_iterations = (hist_size + BLOCK_N - 1) // BLOCK_N + x = tl.zeros([BLOCK_N], ExpertHist.dtype.element_ty) + for i in range(loop_iterations): + offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < hist_size + hist2 = tl.load(ExpertHist + offs_n, mask=mask_n) + tok_starts = tl.cumsum(hist2, 0) - hist2 + x + x += tl.sum(hist2, 0) + tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n) + offs_n += BLOCK_N + + +@triton.jit +def _routing_compute_indx_offs(TokensStart, PartialHist, PartialOffs, shape_pm, stride_pm, BLOCK_M: tl.constexpr): + expt_id = tl.program_id(0) + offs_m = tl.arange(0, BLOCK_M) + # initialize first row of the output + start = tl.load(TokensStart + expt_id) + tl.store(PartialOffs + expt_id, start) + # iterate over input data + curr_sum = start + for _ in range(0, shape_pm, BLOCK_M): + offs = offs_m * stride_pm + expt_id + curr = tl.load(PartialHist + offs, mask=offs_m < shape_pm) + out = tl.cumsum(curr, 0) + curr_sum + curr_sum += tl.sum(curr, 0) + offs = (1 + offs_m) * stride_pm + expt_id + tl.store(PartialOffs + offs, out, mask=offs_m < shape_pm - 1) + offs_m += BLOCK_M + + +@triton.jit +def _keyed_add(x, y): + + # we keep the key in the upper 16 bits of a uint32: + key_mask: tl.constexpr = 0xffff0000 + + kx = x & key_mask + ky = y & key_mask + z = tl.where(kx == ky, x + y - kx, y) + return z + + +@triton.jit +def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, n_gates, + BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr): + + pid_m = tl.program_id(0) + + tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768) + + local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M) + offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs + expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32) + + # stable-sort by expert ID: + kv_pairs = ((expert << 16) | local_offs).to(tl.uint32) + kv_pairs = tl.sort(kv_pairs, 0) + expert = kv_pairs >> 16 + offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xffff) + mask = expert != 0xffff + gate_scal = tl.load(ExptScal + offs, mask=mask) + + # compute run lengths in expert-sorted order: + x = (kv_pairs & 0xffff0000 | 0x00000001) + expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add) + exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xffff + + gates = tl.load(PartialOffs + pid_m * stride_pm + expert, mask=(expert != 0xffff)) + gates += exclusive_run_lengths + + tl.store(ScatterIndx + offs, gates, mask=mask) + tl.store(GatherIndx + gates, offs, mask=mask) + tl.store(GateScal + gates, gate_scal, mask=mask) + + 
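# Illustrative reference (plain Python, helper names assumed) for the keyed scan used by
# _keyed_add and _routing_compute_indx above: the expert id lives in the upper 16 bits,
# a running count in the lower 16 bits, and the count restarts whenever the key changes,
# which yields per-expert run lengths after the stable sort by expert id.
def keyed_add_ref(x: int, y: int) -> int:
    key_mask = 0xFFFF0000
    kx, ky = x & key_mask, y & key_mask
    return x + y - kx if kx == ky else y

def keyed_inclusive_scan_ref(values):
    out, acc = [], None
    for v in values:
        acc = v if acc is None else keyed_add_ref(acc, v)
        out.append(acc)
    return out

# Packing (expert_id << 16) | 1 for sorted expert ids [3, 3, 3, 5, 5] gives low-16-bit
# run lengths [1, 2, 3, 1, 2] in the scan output.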
+@triton.jit +def _routing_clear_bitmatrix(Bitmatrix, stride_bm, shape_bn, cutoff, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + cutoff_word = cutoff // 32 + cutoff_bit = cutoff % 32 + cutoff_mask = (1 << (cutoff_bit)) - 1 + for start_n in range(0, shape_bn, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + values = tl.load(Bitmatrix + pid_m * stride_bm + offs_n, mask=offs_n < shape_bn) + values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values) + values = tl.where(offs_n > cutoff_word, 0, values) + tl.store(Bitmatrix + pid_m * stride_bm + offs_n, values, mask=offs_n < shape_bn) + + +@triton.jit +def _routing_memset_indx(Indx, size, sentinel, BLOCK: tl.constexpr, ExpertHist, FinalExpertOffs, hist_size, + BLOCK_N: tl.constexpr): + pid = tl.program_id(0) + + if pid == 0: + _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N) + else: + offs = (pid - 1) * BLOCK + tl.arange(0, BLOCK) + mask = offs < size + tl.store(Indx + offs, sentinel, mask=mask) diff --git a/kernel-microbench/tk/triton_kernels/specialize.py b/kernel-microbench/tk/triton_kernels/specialize.py new file mode 100644 index 0000000..f3b8ba7 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/specialize.py @@ -0,0 +1,98 @@ +import inspect +import re +import textwrap +import types +import triton + + +def define_kernel(src, module, attrs=None, **extra_globals): + """ + Dynamically create a Triton function or kernel from a src string, + linking any symbols in the kernel to objects specified by extra_globals. + """ + + # create templace function + def _empty_fn(): + pass + + gdict = dict(**(_empty_fn.__globals__)) + gdict.update(extra_globals) + f = types.FunctionType(_empty_fn.__code__, gdict) + f.__module__ = module.__name__ + + src = textwrap.dedent(src) + src = src[src.find("def "):] + + stored_functions = [] + function_name = src[4:].split("(")[0].strip() + + exec_globals = gdict + exec_globals.update({"stored_functions": stored_functions}) + exec(src + "\n\nstored_functions.append(" + function_name + ")\n", exec_globals) + + f.__signature__ = inspect.signature(stored_functions[0]) + f.__name__ = function_name + f.__doc__ = stored_functions[0].__doc__ + + if attrs is None: + attrs = dict() + f = triton.JITFunction(f, **attrs) + f._unsafe_update_src(src) + return f + + +def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()): + assert isinstance(fn, triton.runtime.jit.JITFunction) + if name is None: + name = f"{fn.__name__}" + # Get original source code + src = inspect.getsource(fn.fn) + src = textwrap.dedent(src) + lines = src.split("\n") + # Skip decorator and def line + def_idx = next(i for i, line in enumerate(lines) if line.strip().startswith("def")) + # separate header vs body LOC + header_end = def_idx + while not lines[header_end].rstrip().endswith(":"): + header_end += 1 + body_lines = lines[header_end + 1:] + header_lines = lines[def_idx:header_end + 1] + # clean-up header + header_clean = [ + l.split("#", 1)[0].strip() # keep code, discard comment + for l in header_lines + if l.split("#", 1)[0].strip() # skip blank‑after‑comment lines + ] + # decompose arguments + header_src = " ".join(header_clean) # turn it into a single line + m = re.search(r"\((.*)\)\s*:", header_src) + if not m: + raise ValueError("Could not parse function header") + args_str = m.group(1) + args = [arg.strip() for arg in args_str.split(",") if arg.strip()] + non_specialized_args = [] + for arg in args: + arg_key = arg.split(":")[0].split("=")[0].strip() + new_args = 
tuples.get(arg_key, [arg]) + if arg_key not in constants: + non_specialized_args += new_args + # add global symbols + extra_globals = {v.__name__: v for k, v in constants.items() if isinstance(v, triton.runtime.jit.JITFunction)} + extra_globals.update(fn.__globals__) + # build new source code and define kernel dynamically + new_signature = f"def {name}({', '.join(non_specialized_args)}):" + constexpr_lines = [ + f" {key}: tl.constexpr = {value.__name__ if callable(value) else value}" for key, value in constants.items() + ] + tuple_lines = [ + f" {key} = {'(' + ','.join(value) + (',' if len(value)>=1 else '') + ')'}" for key, value in tuples.items() + ] + new_src = "\n".join(["@triton.jit", new_signature] + constexpr_lines + tuple_lines + body_lines) + # find function parameters + sig = inspect.signature(triton.runtime.jit.JITFunction.__init__) + params = list(sig.parameters.values())[2:] + attrs = {param.name: getattr(fn, param.name, param.default) for param in params} + if do_not_specialize: + attrs["do_not_specialize"] = do_not_specialize + ret = define_kernel(new_src, module, attrs, **extra_globals) + return ret diff --git a/kernel-microbench/tk/triton_kernels/swiglu.py b/kernel-microbench/tk/triton_kernels/swiglu.py new file mode 100644 index 0000000..66d6dd5 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/swiglu.py @@ -0,0 +1,98 @@ +from dataclasses import dataclass +from triton_kernels.numerics import InFlexData, OutFlexData +import torch +import triton +from .swiglu_details._swiglu import _swiglu +from triton_kernels import target_info +from .matmul_ogs_details.metadata import compute_metadata + + +@dataclass(frozen=True) +class FlexCtx: + out_data: OutFlexData = OutFlexData() + inp_data: InFlexData = InFlexData() + saturate_inf: bool = False + + +@dataclass(frozen=True) +class PrecisionConfig: + limit: float + flex_ctx: FlexCtx = FlexCtx() + + +class SwiGLU(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, alpha, precision_config, routing_data): + N = a.shape[-1] + M = a.numel() // N + assert a.stride()[-1] == 1 + assert a.shape[-1] % 2 == 0 + out = torch.empty(size=(M, N // 2), dtype=a.dtype, device=a.device) + flex_ctx = precision_config.flex_ctx + # optimization hyperparameters + BLOCK_M, BLOCK_N = 32 // a.itemsize, 128 + num_warps = 4 + kwargs = {'maxnreg': 64} if not target_info.is_hip() else {} + # launch semi-persistent kernel + N_BLOCKS = triton.cdiv(N // 2, BLOCK_N) + num_sms = target_info.num_sms() + if routing_data is not None: + waves_per_sm = 32 if target_info.is_hip() else 128 + num_pid = num_sms * (waves_per_sm // num_warps) + M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS)) + grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), ) + else: + M_BLOCKS = triton.cdiv(M, BLOCK_M) + if M_BLOCKS * N_BLOCKS >= 8 * num_sms: + grid = (8 * num_sms, ) + else: + grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), ) + n_tokens = None + if routing_data is not None: + n_tokens = compute_metadata(routing_data, M, BLOCK_M).offs[routing_data.n_expts_tot] + _swiglu[grid]( + flex_ctx.out_data.reinterpret(out), + flex_ctx.out_data.expected_scale, + flex_ctx.out_data.actual_scale, + flex_ctx.out_data.checksum_scale, + flex_ctx.inp_data.reinterpret(a), + flex_ctx.inp_data.scale, + alpha, + M, + N // 2, + a.shape[-1], + 1, + out.shape[-1], + 1, + precision_config.limit, + n_tokens, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + EVEN_N=(N // 2) % BLOCK_N == 0, + M_BLOCKS=M_BLOCKS, + N_BLOCKS=N_BLOCKS, + flexpoint_saturate_inf=flex_ctx.saturate_inf, + num_warps=num_warps, + **kwargs, + 
) + out = out.view(a.shape[:-1] + out.shape[-1:]) + return out + + +def swiglu(a, alpha, precision_config, routing_data=None): + return SwiGLU.apply(a, alpha, precision_config, routing_data) + + +def swiglu_torch(a, alpha, precision_config): + limit = precision_config.limit + a_gelu = a[..., ::2] + if limit is not None: + a_gelu = a_gelu.clamp(max=limit) + a_linear = a[..., 1::2] + if limit is not None: + a_linear = a_linear.clamp(min=-limit, max=limit) + + out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu) + out = out_gelu * (a_linear + 1) + return out diff --git a/kernel-microbench/tk/triton_kernels/swiglu_details/_swiglu.py b/kernel-microbench/tk/triton_kernels/swiglu_details/_swiglu.py new file mode 100644 index 0000000..751b6c9 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/swiglu_details/_swiglu.py @@ -0,0 +1,94 @@ +from triton_kernels.numerics_details.flexpoint import load_scale, float_to_flex, update_scale +import triton +import triton.language as tl + + +@triton.jit +def clip(x, limit, clip_lower: tl.constexpr): + res = tl.minimum(x, limit) + if clip_lower: + res = tl.maximum(-limit, res) + return res + + +@triton.jit +def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr, NUM_THREADS: tl.constexpr): + return tl.max(tl.reshape(tl.abs(x), [NUM_THREADS, BLOCK_SIZE // NUM_THREADS], can_reorder=True), axis=1) + + +def swiglu_repr(specialization): + signature = specialization.signature + constants = specialization.constants + convert_dtype = lambda dtype: "mxfp4" if "u8" in dtype else dtype + dtypes = "x".join([convert_dtype(f"{signature[i][1:]}") for i in ["Out", "A"]]) + blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N"]]) + return f"_swiglu_{dtypes}_{blocks}" + + +def swiglu_launch_metadata(grid, kernel, args): + M, N = args["M"], args["N"] + ret = dict() + ret["name"] = f"{kernel.name} [M = {M}, N = {N}]" + A, Out = args["A"], args["Out"] + ret["bytes"] = Out.numel() * Out.element_size() + A.numel() * A.element_size() + return ret + + +@triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata) +def _swiglu(Out, OutExpectedScale, OutActualScale, OutChecksumScale, A, AScale, alpha, M, N, stride_am, stride_an, + stride_outm, stride_outn, limit: tl.constexpr, NTokens, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS, flexpoint_saturate_inf: tl.constexpr): + if NTokens is not None: + M = tl.load(NTokens) + M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M + + local_max = tl.full([tl.extra.cuda.num_threads()], 0.0, tl.float32) + + a_scale = load_scale(AScale) + out_expected_scale = load_scale(OutExpectedScale) + + for pid in tl.range(tl.program_id(0), M_BLOCKS * N_BLOCKS, tl.num_programs(0), num_stages=2): + pid_m = (pid // N_BLOCKS) + pid_n = (pid % N_BLOCKS) + off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = off_m < M + mask_n = off_n < N + packed_off_n = pid_n * BLOCK_N + tl.arange(0, 2 * BLOCK_N) // 2 + packed_mask_n = packed_off_n < N + packed_mask_n = tl.max_constancy(packed_mask_n, [16]) + # load a + packed_off_n = pid_n * 2 * BLOCK_N + tl.arange(0, 2 * BLOCK_N) + packed_offs = off_m[:, None] * stride_am + packed_off_n[None, :] * stride_an + if EVEN_N: + a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.) + else: + if pid_n * BLOCK_N + BLOCK_N <= N: + a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.) 
+ else: + packed_mask = mask_m[:, None] and packed_mask_n[None, :] + a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.) + a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2))) + # a gelu + a_gelu = a_gelu.to(tl.float32) * a_scale + if limit is not None: + a_gelu = clip(a_gelu, limit, clip_lower=False) + # a linear + a_linear = a_linear.to(tl.float32) * a_scale + if limit is not None: + a_linear = clip(a_linear, limit, clip_lower=True) + # compute output + s = a_gelu / (1 + tl.exp(-alpha * a_gelu)) + out = tl.fma(s, a_linear, s) # (s * (a_linear + 1)) + # update flexpoint stats and divide by scale + # we don't need masking because of the `other` when loading `A` + if OutActualScale is not None: + absmax = thread_local_absmax(out, out.numel, tl.extra.cuda.num_threads()) + local_max = tl.maximum(local_max, absmax) + out = float_to_flex(out, out_expected_scale, + None, # ActualScale: local absmax is tracked and updated after the loop + OutChecksumScale, None, Out, flexpoint_saturate_inf) + mask = mask_m[:, None] if EVEN_N else mask_m[:, None] and mask_n[None, :] + tl.store(Out + off_m[:, None] * stride_outm + off_n[None, :] * stride_outn, out, mask) + + update_scale(local_max, OutActualScale, Out) diff --git a/kernel-microbench/tk/triton_kernels/target_info.py b/kernel-microbench/tk/triton_kernels/target_info.py new file mode 100644 index 0000000..1d3286d --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/target_info.py @@ -0,0 +1,46 @@ +import torch +import triton + +cached_capabilities = {} + + +def is_hip(): + if "is_hip" not in cached_capabilities: + cached_capabilities["is_hip"] = torch.cuda.is_available() and bool(torch.version.hip) + return cached_capabilities["is_hip"] + + +def cuda_capability_geq(major, minor=0): + """ + Determines whether we have compute capability >= (major, minor) and + returns this as a constexpr boolean. This can be used for guarding + inline asm implementations that require a certain compute capability. + """ + if is_hip(): + return False + if "cuda" not in cached_capabilities: + if torch.cuda.is_available(): + cached_capabilities["cuda"] = torch.cuda.get_device_capability() + else: + cached_capabilities["cuda"] = (0, 0) + return cached_capabilities["cuda"] >= (major, minor) + + +def get_cdna_version(): + """ + Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently + only supports 3 (gfx942) or 4 (gfx950). 
Returns -1 if it is not AMD + hardware or unsupported architecture + """ + target = triton.runtime.driver.active.get_current_target() + if target.backend != 'hip': + return -1 + if target.arch == 'gfx942': + return 3 + if target.arch == 'gfx950': + return 4 + return -1 + + +def num_sms(): + return torch.cuda.get_device_properties(0).multi_processor_count diff --git a/kernel-microbench/tk/triton_kernels/testing.py b/kernel-microbench/tk/triton_kernels/testing.py new file mode 100644 index 0000000..d905725 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/testing.py @@ -0,0 +1,192 @@ +import enum +import functools +import os +import subprocess +import sys +import torch +from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5 + + +def assert_equal(ref, tri): + if isinstance(ref, torch.Tensor): + assert torch.all(ref == tri) + else: + assert ref == tri + + +def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True): + if tri.dtype.itemsize == 1: + ref_as_type = ref.to(tri.dtype) + if ref.dtype == tri.dtype: + assert torch.all(ref_as_type == tri) + return + ref = ref_as_type + + if maxtol is None: + maxtol = 2e-2 + if rmstol is None: + rmstol = 4e-3 + """ + Compare reference values against obtained values. + """ + + # cast to float32: + ref = ref.to(torch.float32).detach() + tri = tri.to(torch.float32).detach() + assert ref.shape == tri.shape, f"Tensors must have same size {ref.shape=} {tri.shape=}" + + # deal with infinite elements: + inf_mask_ref = torch.isinf(ref) + inf_mask_tri = torch.isinf(tri) + assert torch.equal(inf_mask_ref, inf_mask_tri), "Tensor must have same infinite elements" + refn = torch.where(inf_mask_ref, 0, ref) + trin = torch.where(inf_mask_tri, 0, tri) + + # normalise so that RMS calculation doesn't overflow: + eps = 1.0e-30 + multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps) + refn *= multiplier + trin *= multiplier + + ref_rms = torch.sqrt(torch.square(refn).mean()) + eps + + rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn)) + max_err = torch.max(rel_err).item() + rms_err = torch.sqrt(torch.square(rel_err).mean()).item() + + if verbose: + print("%s maximum relative error = %s (threshold = %s)" % (description, max_err, maxtol)) + print("%s RMS relative error = %s (threshold = %s)" % (description, rms_err, rmstol)) + + if max_err > maxtol: + bad_idxs = torch.nonzero(rel_err > maxtol) + num_nonzero = bad_idxs.size(0) + bad_idxs = bad_idxs[:1000] + print("%d / %d mismatched elements (shape = %s) at coords %s" % + (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist())) + + bad_idxs = bad_idxs.unbind(-1) + print("ref values: ", ref[tuple(bad_idxs)].cpu()) + print("tri values: ", tri[tuple(bad_idxs)].cpu()) + + assert max_err <= maxtol + assert rms_err <= rmstol + + +class ComputeSanitizerTool(enum.Enum): + MEMCHECK = "memcheck" + RACECHECK = "racecheck" + SYNCCHECK = "synccheck" + INITCHECK = "initcheck" + + +def compute_sanitizer(**target_kwargs): + """ + Decorator to run a test with compute sanitizer enabled and pytorch caching allocator disabled, + to expose potential memory access errors. + This decorator requires the `request` fixture to be present. + If `run_sanitizer` argument is present and set to False, the sanitizer is not run. 
+ Running tests under compute sanitizer requires launching subprocess and is slow, + so use sparingly + """ + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1": + test_fn(*args, **kwargs) + return + + import psutil + + if target_kwargs.pop("clear_torch_cache", False): + # If we don't pop clear_torch_cache, it won't pass + # target_kwargs.items() <= kwargs.items() condition below. + torch.cuda.empty_cache() + tools_to_check = target_kwargs.pop("tools_to_check", [ComputeSanitizerTool.MEMCHECK]) + assert isinstance(tools_to_check, list), f"{tools_to_check=}" + assert all(tool in ComputeSanitizerTool for tool in tools_to_check), ( + f"{(tool for tool in tools_to_check if tool not in ComputeSanitizerTool)=}") + + ppid_name = psutil.Process(os.getppid()).exe() + run_compute_sanitizer = target_kwargs.items() <= kwargs.items() + if "run_sanitizer" in kwargs: + run_compute_sanitizer &= kwargs["run_sanitizer"] + if run_compute_sanitizer and "compute-sanitizer" not in ppid_name: + for tool in tools_to_check: + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = { + "PATH": os.environ["PATH"], + "PYTORCH_NO_CUDA_MEMORY_CACHING": "1", + "TORCH_SHOW_CPP_STACKTRACES": "1", + "CUDA_LAUNCH_BLOCKING": "1", + } + if "CUDA_VISIBLE_DEVICES" in os.environ: + env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"] + assert "request_fixture" in kwargs, ( + "memcheck'ed test must have a (possibly unused) `request` fixture") + test_id = kwargs["request_fixture"].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + cmd = [ + "compute-sanitizer", + "--target-processes=application-only", + "--destroy-on-device-error=context", + f"--tool={tool.value}", + sys.executable, + "-m", + "pytest", + "-vsx", + cmd, + ] + for opt in ["--update_checksum", "--ignore_checksum_error"]: + if opt in sys.argv: + cmd.append(opt) + out = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + ) + sanitizer_ok = "ERROR SUMMARY: 0 errors" in str( + out.stdout) or "RACECHECK SUMMARY: 0 hazards displayed" in str(out.stdout) + test_output = out.stdout + if type(test_output) is bytes: + test_output = test_output.decode() + + fail = False + if not sanitizer_ok: + print("compute-sanitizer returned an error") + fail = True + elif out.returncode != 0: + print( + "The test failed due to some other reason: consider running without compute-sanitizer to verify." 
+ ) + print(f"{out.returncode=}") + fail = True + + if fail: + print("*****************************************************") + print("******************** TEST OUTPUT ********************") + print("*****************************************************") + print(test_output) + print("*****************************************************") + print("****************** TEST OUTPUT END ******************") + print("*****************************************************") + assert None + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +def compute_actual_scale(x, dtype): + max_finite = { + torch.float8_e5m2: MAX_FINITE_FLOAT8E5, + torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV, + torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8, + }[dtype] + return x.abs().max() / max_finite diff --git a/kernel-microbench/tk/triton_kernels/topk.py b/kernel-microbench/tk/triton_kernels/topk.py new file mode 100644 index 0000000..edbad32 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/topk.py @@ -0,0 +1,32 @@ +import torch +from .topk_details._topk import _topk +from triton_kernels import Bitmatrix + + +def topk(x, k, dim=1, return_bitmatrix=True): + cdiv = lambda a, b: (a + b - 1) // b + BLOCK_M = 8 + BLOCK_N = 128 + assert x.dtype.itemsize == 2 + assert x.ndim == 2 + assert x.shape[-1] < 32768 + assert dim == 1 + assert return_bitmatrix + n_rows, n_cols = x.shape + dev = x.device + n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N + n_cols_words = n_cols_pad // 32 + # scratchpad tensors + # NOTE: these are not returned + y_vals = torch.empty((n_rows, k), dtype=x.dtype, device=dev) + y_indx = torch.empty((n_rows, k), dtype=torch.int16, device=dev) + bitmatrix = torch.empty((n_rows, n_cols_words), dtype=torch.uint32, device=dev) + _topk[(cdiv(n_rows, BLOCK_M), )]( + x, x.stride(0), # inputs + y_vals, y_indx, y_vals.stride(0), # output [topk] + bitmatrix, bitmatrix.stride(0), # output [bitmatrix] + n_rows, n_cols, # shapes + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, # tunable parameter + N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k, # constants + ) + return y_vals, y_indx, Bitmatrix(bitmatrix, [n_rows, n_cols]) diff --git a/kernel-microbench/tk/triton_kernels/topk_details/_topk.py b/kernel-microbench/tk/triton_kernels/topk_details/_topk.py new file mode 100644 index 0000000..8603c31 --- /dev/null +++ b/kernel-microbench/tk/triton_kernels/topk_details/_topk.py @@ -0,0 +1,76 @@ +import triton +import triton.language as tl + + +@triton.jit +def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, + BLOCK_N: tl.constexpr): + + # subtract 1 from loop iterations because we peel the first (masked) iteration: + loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1 + + offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_x_n[None, :] < n_expts_tot + + # first iteration: + X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :] + x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf")) + x = (x.to(tl.uint16, bitcast=True).to(tl.int32) << 16) | offs_x_n[None, :] + x = x.to(tl.float32, bitcast=True) + + acc = tl.topk(x, N_EXPTS_ACT, dim=1) + + # subsequent iterations: + for _i in range(loop_iterations): + acc = tl.bitonic_merge(acc) # ensure sorted ascending for the merge + X_ptrs -= BLOCK_N + offs_x_n -= BLOCK_N + x = tl.load(X_ptrs, mask=mask_m, other=float("-inf")) + x = (x.to(tl.uint16, bitcast=True).to(tl.int32) << 16) | offs_x_n[None, :] + x = x.to(tl.float32, bitcast=True) + acc = tl.maximum(acc, tl.topk(x, 
N_EXPTS_ACT, dim=1)) + + return acc + + +@triton.jit +def _topk(X, stride_xm, # inputs + Yv, Yi, stride_ym, # topk values/indices + Bits, stride_rm, n_rows, # bitmatrix + n_expts_tot, BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, + BLOCK_N: tl.constexpr): + + tl.static_assert(BLOCK_N % 32 == 0) + tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0) + x_dtype: tl.constexpr = X.dtype.element_ty + + # load logits + offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = offs_m[:, None] < n_rows + y = streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD, N_EXPTS_ACT, BLOCK_N) + y = y.to(tl.uint32, bitcast=True) + + # sort result in direction of ascending expert index + y = (y << 16) | (y >> 16) + y = tl.sort(y, dim=1) + y_indices = y >> 16 + y_values = (y & 0x0000FFFF).to(tl.uint16).to(x_dtype, bitcast=True) + y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype) + + # write back + offs_y_n = tl.arange(0, N_EXPTS_ACT) + Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :] + Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :] + tl.store(Yv_ptrs, y_values, mask=mask_m) + tl.store(Yi_ptrs, y_indices, mask=mask_m) + + # pack into bitmatrix + y_div = y_indices // 32 + y_rem = y_indices % 32 + loop_iterations = N_EXPTS_PAD // BLOCK_N + for i in range(loop_iterations): + offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32) + y2 = tl.where(y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0) + r = tl.reduce_or(y2, axis=1) + BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] + tl.store(BitsPtrs, r, mask=mask_m) diff --git a/microbench.sh b/microbench.sh old mode 100644 new mode 100755 index e426a76..395c430 --- a/microbench.sh +++ b/microbench.sh @@ -17,7 +17,9 @@ do then if [ "$kernel" == "triton" ] then - cmd="time proton -k triton microbench.py --workload $workload --profiler $profiler --kernel $kernel" + # combine workload and kernel into a single argument + profile_name="$workload-$kernel" + cmd="time proton -k triton -n $profile_name microbench.py --workload $workload --profiler $profiler --kernel $kernel" else cmd="time proton microbench.py --workload $workload --profiler $profiler --kernel $kernel" fi diff --git a/requirements_hopper.txt b/requirements_hopper.txt new file mode 100644 index 0000000..5d0c55f --- /dev/null +++ b/requirements_hopper.txt @@ -0,0 +1,71 @@ +accelerate==1.6.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.11.18 +aiosignal==1.3.2 +attrs==25.3.0 +bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_aarch64.whl#sha256=cccb2afa70c311d088127ad02bbf161010dff7cc3abf421d3b8b76718b021741 +certifi==2025.4.26 +charset-normalizer==3.4.2 +cmake==3.31.6 +cut-cross-entropy==25.1.1 +datasets==3.6.0 +dill==0.3.8 +docstring_parser==0.16 +filelock==3.18.0 +frozenlist==1.6.0 +fsspec==2025.3.0 +hf_transfer==0.1.9 +huggingface-hub==0.31.2 +idna==3.10 +Jinja2==3.1.6 +liger_kernel==0.5.9 +lit==18.1.8 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +mdurl==0.1.2 +mpmath==1.3.0 +msgspec==0.19.0 +multidict==6.4.3 +multiprocess==0.70.16 +networkx==3.4.2 +ninja==1.11.1.4 +numpy==2.2.5 +packaging==25.0 +pandas==2.2.3 +peft==0.15.2 +pillow==11.2.1 +propcache==0.3.1 +protobuf==3.20.3 +psutil==7.0.0 +pyarrow==20.0.0 +pybind11==2.13.6 +Pygments==2.19.1 +python-dateutil==2.9.0.post0 +pytz==2025.2 +PyYAML==6.0.2 +regex==2024.11.6 +requests==2.32.3 +rich==14.0.0 
+safetensors==0.5.3 +sentencepiece==0.2.0 +setuptools==80.4.0 +shtab==1.7.2 +six==1.17.0 +sympy==1.13.1 +tokenizers==0.21.1 +torch @ https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-linux_aarch64.whl#sha256=993e0e99c472df1d2746c3233ef8e88d992904fe75b8996a2c15439c43ff46c4 +tqdm==4.67.1 +transformers==4.51.3 +triton @ file:///home/jlee436/triton/python +trl==0.15.2 +typeguard==4.4.2 +typing_extensions==4.13.2 +tyro==0.9.20 +tzdata==2025.2 +unsloth @ git+https://github.com/unslothai/unsloth.git@3f03c7250d137abe98cda89abf9f17cf78a70bb7 +unsloth_zoo==2025.5.5 +urllib3==2.4.0 +wheel==0.45.1 +xformers @ git+https://github.com/facebookresearch/xformers.git@1298453cf117c63dd691c55925fe1f41d3c874d6 +xxhash==3.5.0 +yarl==1.20.0
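The swiglu.py wrapper above drives the Triton kernel through a torch.autograd.Function and hides the launch configuration (BLOCK_M scaled by the element size, a grid capped at a small multiple of num_sms()), while swiglu_torch gives an eager reference and testing.py supplies the tolerance-aware comparison. A minimal sketch of how these pieces fit together, assuming kernel-microbench/tk is on PYTHONPATH so the vendored triton_kernels package (numerics, target_info, etc.) imports cleanly and a CUDA device is available:

import torch
from triton_kernels.swiglu import swiglu, swiglu_torch, PrecisionConfig
from triton_kernels.testing import assert_close

# Row-major fp16 activations with an even last dimension, as SwiGLU.forward asserts.
a = torch.randn(256, 512, dtype=torch.float16, device="cuda")
pcfg = PrecisionConfig(limit=10.0)      # clamp limit applied by both the kernel and the reference
out_tri = swiglu(a, 1.702, pcfg)        # launches the semi-persistent _swiglu kernel
out_ref = swiglu_torch(a, 1.702, pcfg)  # eager clamp + sigmoid gating
assert_close(out_ref, out_tri, description="swiglu")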
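topk.py likewise wraps the _topk kernel behind a single call that returns the selected values, their expert indices, and a packed bit-matrix for the routing path. A rough usage sketch under the same assumptions, plus a Triton build new enough to expose tl.topk/tl.sort (the requirements above pin a local Triton checkout) and a package __init__ that exports Bitmatrix as topk.py expects:

import torch
from triton_kernels.topk import topk

# [tokens, experts] router logits; the wrapper asserts a 2-byte dtype, 2 dims, and experts < 32768.
logits = torch.randn(1024, 128, dtype=torch.float16, device="cuda")
y_vals, y_indx, bits = topk(logits, 4)
# y_vals: softmax over each row's top-4 logits, y_indx: int16 expert indices sorted ascending,
# bits: Bitmatrix whose uint32 words carry one bit per (token, expert) marking the selection.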