From a803f4d88a5e4c31e34ccce183a4c33306ec69db Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 6 Aug 2025 11:29:38 +0000 Subject: [PATCH 1/9] Initial sanity check implementation --- ...plit_table_batched_embeddings_benchmark.py | 247 +++++++++++++----- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 127 +++++++-- 2 files changed, 289 insertions(+), 85 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 4ffb7341a5..614e32cf15 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -8,11 +8,13 @@ # pyre-strict +import gzip import logging import os import tempfile from contextlib import nullcontext -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Optional +import yaml import click import numpy as np @@ -1011,7 +1013,31 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @TbeBenchClickInterface.common_options @TbeBenchClickInterface.device_options @TbeBenchClickInterface.vbe_options +@click.option("--batch-size", default=512) +@click.option("--embedding-dim-list", type=str, default="128") +@click.option("--weights-precision", type=SparseType, default=SparseType.FP32) +@click.option("--cache-precision", type=SparseType, default=None) +@click.option("--stoc", is_flag=True, default=False) +@click.option("--iters", default=100) +@click.option("--warmup-runs", default=0) +@click.option("--managed", default="device") +@click.option("--num-embeddings-list", type=str, default="100000") +@click.option("--reuse", default=0.0) +@click.option("--row-wise/--no-row-wise", default=True) +@click.option("--weighted", is_flag=True, default=False) +@click.option("--pooling", type=str, default="sum") +@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value) +@click.option("--flush-gpu-cache-size-mb", default=0) +@click.option("--output-dtype", type=SparseType, default=SparseType.FP32) +@click.option("--save", type=str, default=None) +@click.option("--load", type=str, default=None) +@click.option("--random-weights", is_flag=True, default=False) +@click.option("--compressed", is_flag=True, default=False) +@click.option("--slice-min", type=int, default=None) +@click.option("--slice-max", type=int, default=None) +@click.pass_context def device_with_spec( # noqa C901 + ctx, alpha: float, bag_size_list: str, bag_size_sigma_list: str, @@ -1031,7 +1057,40 @@ def device_with_spec( # noqa C901 bounds_check_mode: int, flush_gpu_cache_size_mb: int, output_dtype: SparseType, + save: str, + load: str, + random_weights: bool, + compressed: bool, + slice_min: int, + slice_max: int, ) -> None: + if load: + with open(f"{load}/params.yaml", "r") as f: + ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader) + alpha = ctx.params["alpha"] + bag_size_list = ctx.params["bag_size_list"] + bag_size_sigma_list = ctx.params["bag_size_sigma_list"] + batch_size = ctx.params["batch_size"] + embedding_dim_list = ctx.params["embedding_dim_list"] + weights_precision = ctx.params["weights_precision"] + cache_precision = ctx.params["cache_precision"] + stoc = ctx.params["stoc"] + iters = ctx.params["iters"] + warmup_runs = ctx.params["warmup_runs"] + managed = ctx.params["managed"] + num_embeddings_list = ctx.params["num_embeddings_list"] + reuse = ctx.params["reuse"] + row_wise = ctx.params["row_wise"] + weighted = ctx.params["weighted"] + pooling = ctx.params["pooling"] + bounds_check_mode = ctx.params["bounds_check_mode"] + flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"] + output_dtype = ctx.params["output_dtype"] + random_weights = ctx.params["random_weights"] + compressed = ctx.params["compressed"] + slice_min = ctx.params["slice_min"] + slice_max = ctx.params["slice_max"] + np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -1040,6 +1099,12 @@ def device_with_spec( # noqa C901 T = len(Ds) use_variable_bag_sizes = bag_size_sigma_list != "None" + + params = ctx.params + if save: + os.makedirs(f"{save}", exist_ok=True) + with open(f"{save}/params.yaml", "w") as f: + yaml.dump(params, f, sort_keys=False) if use_variable_bag_sizes: Ls = [int(mu) for mu in bag_size_list.split(",")] @@ -1118,6 +1183,22 @@ def device_with_spec( # noqa C901 if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) + elif random_weights: + emb.init_embedding_weights_uniform(-1.0, 1.0) + + if save: + if compressed: + with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/model_state.pth") + + if load: + if compressed: + with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: + emb.load_state_dict(torch.load(f)) + else: + emb.load_state_dict(torch.load(f"{load}/model_state.pth")) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 @@ -1130,52 +1211,68 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - for t, e in enumerate(Es): - # (indices, offsets, weights) - requests = generate_requests( - iters, - B, - 1, - Ls[t], - e, - reuse=reuse, - alpha=alpha, - weighted=weighted, - # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. - sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, - zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - for i, req in enumerate(requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - prev_indices_len = -1 - requests = [] - for i in range(iters): - indices = torch.concat(all_requests["indices"][i]) - if prev_indices_len == -1: - prev_indices_len = indices.numel() - assert ( - prev_indices_len == indices.numel() - ), "Number of indices for every iteration must be the same" - offsets = torch.concat(all_requests["offsets"][i]) - if weighted: - weights = torch.concat(all_requests["weights"][i]) - else: - weights = None - requests.append(TBERequest(indices, offsets, weights)) - - del all_requests + if load: + requests = [] + for i in range(iters): + indices = torch.load(f"{load}/{i}_indices.pt") + offsets = torch.load(f"{load}/{i}_offsets.pt") + per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") + Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt") + requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank)) + else: + for t, e in enumerate(Es): + # (indices, offsets, weights) + requests = generate_requests( + iters, + B, + 1, + Ls[t], + e, + reuse=reuse, + alpha=alpha, + weighted=weighted, + # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. + sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, + zipf_oversample_ratio=3 if Ls[t] > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) + for i, req in enumerate(requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + prev_indices_len = -1 + requests = [] + for i in range(iters): + indices = torch.concat(all_requests["indices"][i]) + if prev_indices_len == -1: + prev_indices_len = indices.numel() + assert ( + prev_indices_len == indices.numel() + ), "Number of indices for every iteration must be the same" + offsets = torch.concat(all_requests["offsets"][i]) + if weighted: + weights = torch.concat(all_requests["weights"][i]) + else: + weights = None + requests.append(TBERequest(indices, offsets, weights)) + del all_requests + assert len(requests) == iters + if save: + for i in range(iters): + req = requests[i] + torch.save(req.indices, f"{save}/{i}_indices.pt") + torch.save(req.offsets, f"{save}/{i}_offsets.pt") + torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") + torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: @@ -1201,36 +1298,44 @@ def device_with_spec( # noqa C901 f"Accessed weights per batch: {B * sum_DLs * param_size_multiplier / 1.0e9: .2f} GB" ) + if load is None and save is None: # forward - time_per_iter = benchmark_requests( - requests, - lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - ), - flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, - num_warmups=warmup_runs, - ) - logging.info( - f"Forward, B: {B}, " - f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " - f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"T: {time_per_iter * 1.0e6:.0f}us" - ) + time_per_iter = benchmark_requests( + requests, + lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, + ) + logging.info( + f"Forward, B: {B}, " + f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) if output_dtype == SparseType.INT8: # backward bench not representative return - if do_pooling: - grad_output = torch.randn(B, sum(Ds)).to(get_device()) + if load: + grad_output = torch.load(f"{load}/grad_output.pt") else: - # Obtain B * L from indices len - # pyre-ignore[19] - # pyre-fixme[61]: `D` is undefined, or not always defined. - grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) + else: + # Obtain B * L from indices len + # pyre-ignore[19] + # pyre-fixme[61]: `D` is undefined, or not always defined. + grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + + if save: + torch.save(grad_output, f"{save}/grad_output.pt") + # backward time_per_iter = benchmark_requests( requests, @@ -1244,6 +1349,12 @@ def device_with_spec( # noqa C901 bwd_only=True, grad=grad_output, num_warmups=warmup_runs, + emb=emb, + save=save, + load=load, + compressed=compressed, + slice_min=slice_min, + slice_max=slice_max, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 1243f14db4..7ccaad95e1 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -11,6 +11,7 @@ import statistics import threading import time +import gzip from subprocess import Popen from typing import Callable, Optional @@ -18,6 +19,8 @@ from fbgemm_gpu.tbe.utils import b_indices, TBERequest from fbgemm_gpu.tbe.utils.common import get_device +from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen + logging.basicConfig(level=logging.DEBUG) @@ -241,35 +244,43 @@ def benchmark_requests( # noqa: C901 periodic_logs: bool = False, warmup_ms: Optional[int] = None, iters: int = -1, + emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, + save: Optional[str] = None, + load: Optional[str] = None, + compressed: bool = False, + slice_min: Optional[int] = None, + slice_max: Optional[int] = None, ) -> float: times = [] # Run at least one warmup iteration to avoid the long cudaLaunchKernel time # for the first kernel if warmup_ms > 0 # warmup_ms is prioritized over num_warmups - + import copy if warmup_ms is None: num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 - # warm-up the GPU before profiling - bench_warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - num_warmups, - lambda indices, offsets, per_sample_weights: func( - indices, - offsets, - per_sample_weights, - ), - bwd_only=bwd_only, - grad=grad, - ) + if not (load or save): + # warm-up the GPU before profiling + bench_warmup( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights: func( + indices, + offsets, + per_sample_weights, + ), + bwd_only=bwd_only, + grad=grad, + ) - if callback_after_warmup is not None: - callback_after_warmup() + if callback_after_warmup is not None: + callback_after_warmup() num_reqs = len(requests) iters = num_reqs if iters == -1 else iters + sliced = slice_min is not None and slice_max is not None if torch.cuda.is_available(): torch.cuda.synchronize() @@ -279,6 +290,88 @@ def benchmark_requests( # noqa: C901 start_events = [] end_events = [] + if save and emb: + for it in range(iters): + req = requests[it % num_reqs] + + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + if compressed: + with gzip.open(f"{save}/{it}_fwd_grad_out.pt.gz", "wb") as f: + torch.save(out, f) + else: + torch.save(out, f"{save}/{it}_fwd_grad_out.pt") + + out.backward(grad) + torch.cuda.synchronize() + torch.save(out, f"{save}/{it}_bwd_grad_out.pt") + + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: + torch.save(t[slice_min:slice_max,:].clone(), f) + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") + torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") + + else: + if compressed: + with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") + + if load and emb: + for it in range(iters): + req = requests[it % num_reqs] + + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + + out.backward(grad) + torch.cuda.synchronize() + emb_ref = copy.deepcopy(emb) + if not sliced: + if compressed: + with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: + emb_ref.load_state_dict(torch.load(f)) + else: + emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) + + print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: + w_ref = torch.load(f) + else: + w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") + torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + else: + for id, t in enumerate(emb.split_embedding_weights()): + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + print("PASS") + + print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) + if sliced: + m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") + m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") + else: + m_dev_ref = emb_ref.momentum1_dev + m_uvm_ref = emb_ref.momentum1_uvm + torch.testing.assert_close(emb.momentum1_dev, m_dev_ref) + torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref) + print("PASS") + + for it in range(iters): req = requests[it % num_reqs] From 6e49e19873b4ad6d53ab40def08d8f5ffb17c90a Mon Sep 17 00:00:00 2001 From: jichen Date: Fri, 24 Oct 2025 08:36:00 +0000 Subject: [PATCH 2/9] add fwd sanity check --- ...plit_table_batched_embeddings_benchmark.py | 274 ++++++++++-------- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 124 +++++++- 2 files changed, 264 insertions(+), 134 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 614e32cf15..2d3755fe06 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -1367,19 +1367,19 @@ def device_with_spec( # noqa C901 @click.option( "--batch-size-list", type=str, - required=True, + required=False, help="A comma separated list of batch sizes (B) for each table.", ) @click.option( "--embedding-dim-list", type=str, - required=True, + required=False, help="A comma separated list of embedding dimensions (D) for each table.", ) @click.option( "--bag-size-list", type=str, - required=True, + required=False, help="A comma separated list of bag sizes (L) for each table.", ) @click.option( @@ -1392,7 +1392,7 @@ def device_with_spec( # noqa C901 @click.option( "--num-embeddings-list", type=str, - required=True, + required=False, help="A comma separated list of number of embeddings (E) for each table.", ) @click.option( @@ -1405,7 +1405,7 @@ def device_with_spec( # noqa C901 @click.option( "--num-tables", type=int, - required=True, + required=False, help="The number of tables.", ) @click.option( @@ -1414,16 +1414,12 @@ def device_with_spec( # noqa C901 default=False, help="Whether the table is weighted or not", ) -@click.option( - "--print-kernel-summary", - is_flag=True, - default=False, - help="Whether the table is weighted or not", -) -@click.option("--ssd", is_flag=True, default=False) -@click.option( - "--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix" -) +@click.option("--save", type=str, default=None) +@click.option("--load", type=str, default=None) +@click.option("--random-weights", is_flag=True, default=False) +@click.option("--compressed", is_flag=True, default=False) +@click.option("--slice-min", type=int, default=None) +@click.option("--slice-max", type=int, default=None) @TBEBenchmarkingConfigLoader.options @EmbeddingOpsCommonConfigLoader.options @click.pass_context @@ -1437,9 +1433,12 @@ def vbe( alpha_list: str, num_tables: int, weighted: bool, - print_kernel_summary: bool, - ssd: bool, - ssd_prefix: str, + save: str, + load: str, + random_weights: bool, + compressed: bool, + slice_min: int, + slice_max: int, # pyre-ignore[2] **kwargs, ) -> None: @@ -1451,6 +1450,28 @@ def vbe( np.random.seed(42) torch.manual_seed(42) + if save: + os.makedirs(f"{save}", exist_ok=True) + with open(f"{save}/params.yaml", "w") as f: + yaml.dump(context.params, f, sort_keys=False) + + if load: + with open(f"{load}/params.yaml", "r") as f: + context.params = yaml.load(f, Loader=yaml.UnsafeLoader) + params = context.params + batch_size_list = params["batch_size_list"] + embedding_dim_list = params["embedding_dim_list"] + bag_size_list = params["bag_size_list"] + bag_size_sigma_list = params["bag_size_sigma_list"] + num_embeddings_list = params["num_embeddings_list"] + alpha_list = params["alpha_list"] + num_tables = params["num_tables"] + weighted = params["weighted"] + random_weights = params["random_weights"] + compressed = params["compressed"] + slice_min = params["slice_min"] + slice_max = params["slice_max"] + # Load general TBE benchmarking configuration from cli arguments benchconfig = TBEBenchmarkingConfigLoader.load(context) if benchconfig.num_requests != benchconfig.iterations: @@ -1459,6 +1480,9 @@ def vbe( if benchconfig.flush_gpu_cache_size_mb != 0: raise ValueError("--bench-flush-gpu-cache-size is not supported.") + if benchconfig.export_trace: + raise ValueError("--bench-export-trace is not supported.") + # Load common embedding op configuration from cli arguments embconfig = EmbeddingOpsCommonConfigLoader.load(context) if embconfig.uvm_host_mapped: @@ -1495,126 +1519,122 @@ def vbe( else EmbeddingLocation.HOST ) - common_split_args: dict[str, Any] = { - "weights_precision": embconfig.weights_dtype, - "stochastic_rounding": embconfig.stochastic_rounding, - "output_dtype": embconfig.output_dtype, - "pooling_mode": embconfig.pooling_mode, - "bounds_check_mode": embconfig.bounds_check_mode, - "optimizer": optimizer, - "learning_rate": 0.1, - "eps": 0.1, - "feature_table_map": list(range(T)), - } - - if ssd: - cache_set = max(T * max(Bs), 1) - tempdir = tempfile.mkdtemp(prefix=ssd_prefix) - emb = SSDTableBatchedEmbeddingBags( - [(E, D) for E, D in zip(Es, Ds)], - cache_sets=cache_set, - ssd_storage_directory=tempdir, - ssd_cache_location=EmbeddingLocation.DEVICE, - ssd_rocksdb_shards=8, - **common_split_args, - ) - else: - emb = SplitTableBatchedEmbeddingBagsCodegen( - [ - ( - E, - D, - managed_option, - get_available_compute_device(), - ) - for E, D in zip(Es, Ds) - ], - cache_precision=embconfig.cache_dtype, - **common_split_args, - ) - emb = emb.to(get_device()) - all_requests = { - "indices": [[] for _ in range(benchconfig.iterations)], - "offsets": [[] for _ in range(benchconfig.iterations)], - "weights": [[] for _ in range(benchconfig.iterations)], - } - for t, (E, B, L, sigma_L, alpha) in enumerate(zip(Es, Bs, Ls, sigma_Ls, alphas)): - # Generate a request for a single table. - local_requests = generate_requests( - benchconfig.iterations, - B, - 1, - L, - E, - alpha=alpha, - weighted=weighted, - sigma_L=sigma_L, - zipf_oversample_ratio=3 if L > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) + emb = SplitTableBatchedEmbeddingBagsCodegen( + [ + ( + E, + D, + managed_option, + get_available_compute_device(), + ) + for E, D in zip(Es, Ds) + ], + optimizer=optimizer, + learning_rate=0.1, + eps=0.1, + cache_precision=embconfig.cache_dtype, + weights_precision=embconfig.weights_dtype, + stochastic_rounding=embconfig.stochastic_rounding, + output_dtype=embconfig.output_dtype, + pooling_mode=embconfig.pooling_mode, + bounds_check_mode=embconfig.bounds_check_mode, + ).to(get_device()) + + if random_weights: + emb.init_embedding_weights_uniform(-1.0, 1.0) - # Store requests for each table in all_requests. - for i, req in enumerate(local_requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) + if save: + if compressed: + with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/model_state.pth") - # pyre-ignore[53] - def _kineto_trace_handler( - p: profile, emb_op_type: str = "vbe", print_summary: bool = False - ) -> None: - p.export_chrome_trace( - benchconfig.trace_url.format(emb_op_type=emb_op_type, ospid=os.getpid()) - ) - if print_summary: - print(p.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + if load: + if compressed: + with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: + emb.load_state_dict(torch.load(f)) + else: + emb.load_state_dict(torch.load(f"{load}/model_state.pth")) - emb_op_type = "vbe" + if load: + requests = [] + for i in range(benchconfig.iterations): + indices = torch.load(f"{load}/{i}_indices.pt") + offsets = torch.load(f"{load}/{i}_offsets.pt") + per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") + requests.append((indices, offsets, per_sample_weights)) + else: + all_requests = { + "indices": [[] for _ in range(benchconfig.iterations)], + "offsets": [[] for _ in range(benchconfig.iterations)], + "weights": [[] for _ in range(benchconfig.iterations)], + } + for t, (E, B, L, sigma_L, alpha) in enumerate(zip(Es, Bs, Ls, sigma_Ls, alphas)): + # Generate a request for a single table. + local_requests = generate_requests( + benchconfig.iterations, + B, + 1, + L, + E, + alpha=alpha, + weighted=weighted, + sigma_L=sigma_L, + zipf_oversample_ratio=3 if L > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) - # pyre-ignore[3, 53] - def context_factory(on_trace_ready: Callable[[profile], None]): - return ( - profile(on_trace_ready=on_trace_ready) - if benchconfig.export_trace - else nullcontext() - ) + # Store requests for each table in all_requests. + for i, req in enumerate(local_requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) - # Combine the requests for all tables by - requests = [ - ( - torch.concat(all_requests["indices"][i]), - torch.concat(all_requests["offsets"][i]), - torch.concat(all_requests["weights"][i]) if weighted else None, - ) - for i in range(benchconfig.iterations) - ] + # Combine the requests for all tables by + requests = [ + ( + torch.concat(all_requests["indices"][i]), + torch.concat(all_requests["offsets"][i]), + torch.concat(all_requests["weights"][i]) if weighted else None, + ) + for i in range(benchconfig.iterations) + ] + + del all_requests - del all_requests + if save: + for i, (indices, offsets, weights) in enumerate(requests): + torch.save(indices, f"{save}/{i}_indices.pt") + torch.save(offsets, f"{save}/{i}_offsets.pt") + torch.save(weights, f"{save}/{i}_per_sample_weights.pt") - with context_factory( - lambda p: _kineto_trace_handler(p, emb_op_type, print_kernel_summary) - ): - fwd_time_sec, bwd_time_sec = benchmark_vbe( - requests, - func=lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - batch_size_per_feature_per_rank=[[B] for B in Bs], - ), - num_warmups=benchconfig.warmup_iterations, - ) + fwd_time_sec, bwd_time_sec = benchmark_vbe( + requests, + func=lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank=[[B] for B in Bs], + ), + num_warmups=benchconfig.warmup_iterations, + emb=emb, + save=save, + load=load, + compressed=compressed, + slice_min=slice_min, + slice_max=slice_max, + ) logging.info( f"T: {T}, Bs: {Bs}, Ds: {Ds}, Ls: {Ls}, Es: {Es}\n" f"fwd: {fwd_time_sec * 1.0e6:.0f}us, bwd: {bwd_time_sec * 1.0e6:.0f}us" ) - if __name__ == "__main__": cli() diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 7ccaad95e1..da502d1c21 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -21,6 +21,7 @@ from fbgemm_gpu.tbe.utils.common import get_device from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen +import copy logging.basicConfig(level=logging.DEBUG) @@ -334,6 +335,12 @@ def benchmark_requests( # noqa: C901 out = emb(indices, offsets, weights) torch.cuda.synchronize() + out_ref = torch.load(f"{load}/{it}_fwd_grad_out.pt") + torch.testing.assert_close(out, out_ref, atol=1.0e-3, rtol=1.0e-3) + + print(f"[{it + 1}/{iters}] Forward output check... ", end="", flush=True) + print("FWD PASS") + out.backward(grad) torch.cuda.synchronize() emb_ref = copy.deepcopy(emb) @@ -695,6 +702,12 @@ def benchmark_vbe( requests: list[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor], num_warmups: int = 0, + emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, + save: Optional[str] = None, + load: Optional[str] = None, + compressed: bool = False, + slice_min: Optional[int] = None, + slice_max: Optional[int] = None, ) -> tuple[float, float]: """ A benchmark function to return the average execution time in seconds of @@ -719,14 +732,16 @@ def benchmark_vbe( """ use_cuda = torch.cuda.is_available() + sliced = slice_min is not None and slice_max is not None + if not (load or save): # Warm-ups. - for _ in range(num_warmups): - # Warm-up using the first request as done in benchmark_requests - indices, offsets, weights = requests[0] - out = func(indices, offsets, weights) - grad = torch.rand_like(out) - out.backward(grad) + for _ in range(num_warmups): + # Warm-up using the first request as done in benchmark_requests + indices, offsets, weights = requests[0] + out = func(indices, offsets, weights) + grad = torch.rand_like(out) + out.backward(grad) iters = len(requests) if use_cuda: @@ -740,6 +755,101 @@ def benchmark_vbe( fwd_times_sec = [] bwd_times_sec = [] + if save and emb: + for it, req in enumerate(requests): + + indices, offsets, weights = req + out = func(indices, offsets, weights) + torch.cuda.synchronize() + + torch.save(out, f"{save}/{it}_fwd_out.pt") + + grad = torch.rand_like(out) + if compressed: + with gzip.open(f"{save}/{it}_grad.pt.gz", "wb") as f: + torch.save(grad, f) + else: + torch.save(grad, f"{save}/{it}_grad.pt") + + out.backward(grad) + torch.cuda.synchronize() + + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: + torch.save(t[slice_min:slice_max,:].clone(), f) + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") + torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") + + else: + if compressed: + with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") + + if load and emb: + for it, req in enumerate(requests): + + indices, offsets, weights = req + out = func(indices, offsets, weights) + torch.cuda.synchronize() + + out_ref = torch.load(f"{load}/{it}_fwd_out.pt") + torch.testing.assert_close(out, out_ref, atol=1.0e-3, rtol=1.0e-3) + + print(f"[{it + 1}/{iters}] Forward output check... ", end="", flush=True) + print("FWD PASS") + + if compressed: + with gzip.open(f"{load}/{it}_grad.pt.gz", "rb") as f: + grad = torch.load(f) + else: + grad = torch.load(f"{load}/{it}_grad.pt") + + out.backward(grad) + torch.cuda.synchronize() + emb_ref = copy.deepcopy(emb) + if not sliced: + if compressed: + with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: + emb_ref.load_state_dict(torch.load(f)) + else: + emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) + + print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: + w_ref = torch.load(f) + else: + w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") + torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + else: + for id, t in enumerate(emb.split_embedding_weights()): + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + print("PASS") + + print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) + if sliced: + m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") + m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") + else: + m_dev_ref = emb_ref.momentum1_dev + m_uvm_ref = emb_ref.momentum1_uvm + torch.testing.assert_close(emb.momentum1_dev, m_dev_ref) + torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref) + print("PASS") + + for i, (indices, offsets, weights) in enumerate(requests): # forward if use_cuda: @@ -792,4 +902,4 @@ def benchmark_vbe( # pyre-ignore[61] bwd_time_sec = statistics.median(bwd_times_sec) - return fwd_time_sec, bwd_time_sec + return fwd_time_sec, bwd_time_sec \ No newline at end of file From fb40487c5784408874d36154188828e0c946d417 Mon Sep 17 00:00:00 2001 From: yadai Date: Tue, 28 Oct 2025 02:51:15 +0000 Subject: [PATCH 3/9] update --- ...ing_backward_split_kernel_warp_template.cu | 105 +++++++--- .../embedding_backward_split_template.cu | 15 +- ...optimizer_split_device_kernel_template.cuh | 183 +++++++++++++++++- 3 files changed, 267 insertions(+), 36 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index e26d1834aa..d1f715511d 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -417,6 +417,8 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc {%- endif %} const auto start_run_id = blockIdx.x * blockDim.y + threadIdx.y; +#define SUBWARP_SHFL_SYNC(val, srcLane) __shfl_sync(UINT64_MAX, val, srcLane, kThreadGroupSize) + #ifdef FBGEMM_USE_SUBWARP_SHUFFLE const unsigned int shfl_sync_mask = ((1L << kThreadGroupSize) - 1) << @@ -467,39 +469,53 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc int32_t s_table_unique_indice_offset = is_valid? table_unique_indices_offsets[s_t_0] : 0; int64_t s_weights_offset = is_valid? weights_offsets[s_t_0] : 0; - int64_t s_momentum1_offset = is_valid? momentum1_offsets[s_t_0] : 0; int32_t s_weights_placement = is_valid? weights_placements[s_t_0] : 0; - int32_t s_momentum1_placement = is_valid? momentum1_placements[s_t_0] : 0; - at::acc_type* __restrict__ s_momentum1; - if (static_cast(s_momentum1_placement) == PlacementType::DEVICE) { - s_momentum1 = &momentum1_dev[s_momentum1_offset]; - } else { - s_momentum1 = &momentum1_uvm[s_momentum1_offset]; - } + {# + {%- for tensor in args.split_tensors %} + {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; + const auto s_{{ tensor }}_placement = static_cast({{ tensor }}_placements[t]); + const int64_t s_{{ tensor }}_offset = {{ tensor }}_offsets[t]; + {%- endfor %} + #} + + // at::acc_type* __restrict__ s_momentum1; + // if (static_cast(s_momentum1_placement) == PlacementType::DEVICE) { + // s_momentum1 = &momentum1_dev[s_momentum1_offset]; + // } else { + // s_momentum1 = &momentum1_uvm[s_momentum1_offset]; + // } for (auto i = 0; i < num_valid_id; ++i) { - auto run_id = out_run_id + i; - auto t_0 = BROADCAST(s_t_0, i); - auto idx = BROADCAST(s_idx, i); - auto segment_start = BROADCAST(s_segment_start, i); - auto segment_end = BROADCAST(s_segment_end, i); - auto D = BROADCAST(s_D, i); - int32_t table_unique_indice_offset = BROADCAST(s_table_unique_indice_offset, i); + auto segment_start = SUBWARP_SHFL_SYNC(s_segment_start, i); + auto segment_end = SUBWARP_SHFL_SYNC(s_segment_end, i); const int32_t SL = segment_end - segment_start; - - const int64_t weights_offset = SHFL_SYNC(s_weights_offset, i); - const auto weights_placement = static_cast(SHFL_SYNC(s_weights_placement, i)); - - const int64_t momentum1_offset = SHFL_SYNC(s_momentum1_offset, i); - const auto momentum1_placement = static_cast(SHFL_SYNC(s_momentum1_placement, i)); - auto momentum1 = reinterpret_cast*>(SHFL_SYNC(reinterpret_cast(s_momentum1), i)); - auto momentum1_val = momentum1[idx]; - if (SL >= max_segment_length_per_warp) { continue; } + auto run_id = out_run_id + i; + auto t_0 = SUBWARP_SHFL_SYNC(s_t_0, i); + auto idx = SUBWARP_SHFL_SYNC(s_idx, i); + + {%- if not nobag %} + auto D = SUBWARP_SHFL_SYNC(s_D, i); + {%- endif %} + int32_t table_unique_indice_offset = SUBWARP_SHFL_SYNC(s_table_unique_indice_offset, i); + + {# + {%- for tensor in args.split_tensors %} + {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; + const auto s_{{ tensor }}_placement = static_cast({{ tensor }}_placements[t]); + const int64_t s_{{ tensor }}_offset = {{ tensor }}_offsets[t]; + {%- endfor %} + #} + + // const int64_t momentum1_offset = SHFL_SYNC(s_momentum1_offset, i); + // const auto momentum1_placement = static_cast(SHFL_SYNC(s_momentum1_placement, i)); + // auto momentum1 = reinterpret_cast*>(SHFL_SYNC(reinterpret_cast(s_momentum1), i)); + // auto momentum1_val = momentum1[idx]; + // now, each segment corresponds to exactly one table `t` and row in // that table (`idx`). Thus, we can hoist out some of the book-keeping. @@ -549,7 +565,11 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc // when kUseVecBlocking == false const int32_t max_vecs = kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread; - split_rowwise_adagrad_table_update_kernel< + + {%- if not dense and optimizer != "none" %} + const int64_t weights_offset = SUBWARP_SHFL_SYNC(s_weights_offset, i); + const int32_t weights_placement = SUBWARP_SHFL_SYNC(s_weights_placement, i); + {{ mdesc }}_{{ optimizer }}_table_update_kernel< emb_t, cache_t, {%- for ph_name in args.placeholder_tensor_names %} @@ -585,8 +605,37 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc {%- endif %} shfl_sync_mask, max_vecs, - momentum1, momentum1_val, learning_rate, eps, weight_decay, weight_decay_mode, max_norm + {%- if ssd %} + enable_optimizer_offloading, + {%- endif %} + {{ args.split_kernel_arg_names | join(", ") }} ); + {%- else %} + // Write deduplicated gradient to grad_dev_weights gradient is sparse + // for split_embedding and dense for dense_embedding + {%- if dense %} + const int64_t weights_offset = weights_offsets[t_0]; + {%- else %} + // Compute offset of sparse gradient + const int64_t weights_offset = run_id * max_D; + idx = 0; + {%- endif %} + store_grad_sum< + emb_t, + cache_t, + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + grad_dev_weights, + grad_sum, + kUseVecBlocking ? smem_grad_sum : nullptr, + D, + weights_offset, + idx, + max_vecs + ); + {%- endif %} // if not dense and optimizer != "none" } } } @@ -868,7 +917,7 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc codegen/embedding_common_code_generator.py for more details */ #} -{{ instantiate_templates(use_subwarp_shuffle=False) }} +{{ instantiate_templates(use_subwarp_shuffle=True) }} //////////////////////////////////////////////////////////////////////////////// #endif @@ -1118,7 +1167,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- endmacro %} {%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} - {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} + {%- for grad_type in ['float', 'at::Half'] %} {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 8ff3e56ce4..0dd9b91937 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1212,7 +1212,7 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; {%- if is_optimized_hip_kernel_supported_mode %} - if (use_hip_kernel && mixed_D) { + if (false) { backward_cta_per_row_kernel = {{ hip_mixed_d_cta_kernel }} ; {%- if is_optimized_hip_kernel_supported_mode %} - if (use_hip_kernel && mixed_D) { + if (true) { + printf("%s:%d call here\n", __FILE__, __LINE__); backward_warp_per_row_kernel = {{ hip_mixed_d_warp_kernel }} ; + 1, + 32, + false>; } {%- endif %} @@ -1383,7 +1384,8 @@ Tensor {{ embedding_cuda_op }}( used_shared_bytes); } - auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + // auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + auto blockSize = dim3(32, num_warp_per_row_groups); int32_t warp_per_row_grid_size = std::min( div_round_up(total_unique_indices, num_warp_per_row_groups), @@ -1427,6 +1429,7 @@ Tensor {{ embedding_cuda_op }}( {%- endif %} #endif + FBGEMM_LAUNCH_KERNEL( backward_warp_per_row_kernel, warp_per_row_grid_size, diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh index e4fb6c548c..a18b97133e 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh @@ -11,8 +11,34 @@ #include "fbgemm_gpu/utils/tensor_accessor_builder.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" -#define GROUP_REDUCE_ALL_SUM(val, ...) \ - warpReduceAllSum<__VA_ARGS__, kThreadGroupSize>(val, shfl_sync_mask) +template +DEVICE_INLINE __device__ T subwarp_reduce_add(T value) { + static_assert(kThreadGroupSize == 8 || kThreadGroupSize == 16 || kThreadGroupSize == 32 || kThreadGroupSize == 64, "Wavefront size must be 16/32/64"); + if (kThreadGroupSize == 16) { + // Reduce across 4 groups of 16 threads + value += __shfl_xor(value, 1, 16); + value += __shfl_xor(value, 2, 16); + value += __shfl_xor(value, 4, 16); + value += __shfl_xor(value, 8, 16); + } else if (kThreadGroupSize == 32) { + // Reduce across 2 groups of 32 threads + value += __shfl_xor(value, 1, 32); + value += __shfl_xor(value, 2, 32); + value += __shfl_xor(value, 4, 32); + value += __shfl_xor(value, 8, 32); + value += __shfl_xor(value, 16, 32); + } else if (kThreadGroupSize == 64) { + value += __shfl_xor(value, 1, 64); + value += __shfl_xor(value, 2, 64); + value += __shfl_xor(value, 4, 64); + value += __shfl_xor(value, 8, 64); + value += __shfl_xor(value, 16, 64); + value += __shfl_xor(value, 32, 64); + } + return value; +} + +#define GROUP_REDUCE_ALL_SUM(val, ...) subwarp_reduce_add(val) {%- set mdesc = "ssd" if ssd else "split" %} {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} @@ -176,4 +202,157 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( {{ split_post_update }} } +template < + typename emb_t, + typename cache_t, + {%- for ph_name in args.placeholder_tensor_names %} + {%- set ph_type = "{}_ph_t".format(ph_name) %} + typename {{ ph_type }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize = kWarpSize, + int32_t VEC_WIDTH, + bool kUseVecBlocking +> +DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( + pta::PackedTensorAccessor64& dev_weights, + pta::PackedTensorAccessor64& uvm_weights, + pta::PackedTensorAccessor64& lxu_cache_weights, + const int32_t weights_placement, + const int64_t weights_offset, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits>& sorted_{{ locs_or_addrs_tensor }}, + Vec4TAcc* grad_sum, + Vec4TAcc* smem_grad_sum, + Vec4TAcc* shared_weight_update_row, + const bool stochastic_rounding, + const at::PhiloxCudaState& stochastic_rounding_philox_args, + const uint32_t run_id, + const uint32_t cache_loc_run_id, + const int32_t D, + const int32_t t, + const int64_t idx, + {%- if has_global_weight_decay_support %} + const float global_weight_decay, + {%- endif %} + const uint32_t shfl_sync_mask, + const int32_t max_vecs_per_thread, + {%- if ssd %} + const bool enable_optimizer_offloading, + {%- endif %} + {{ args.split_ref_kernel_args | replace_pta_namespace() | join(",\n ") }} +) { + constexpr auto kIsInt8 = std::is_same_v; + // Copy value to max_vecs to make max_vecs_per_thread known at compile time + // when kUseVecBlocking == false + const int32_t max_vecs = + kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread; + emb_t* __restrict__ weights {nullptr}; + cache_t* __restrict__ cache_weights {nullptr}; + int32_t D_emb = D; + if constexpr (kIsInt8) { + D_emb += kINT8QparamsBytes; + } + if (static_cast(weights_placement) == PlacementType::DEVICE) { + weights = &dev_weights[weights_offset + idx * D_emb]; + } else { + weights = {{ "nullptr" if ssd else "&uvm_weights[weights_offset + idx * D_emb]" }}; + } + if (static_cast(weights_placement) == PlacementType::MANAGED_CACHING) { + const auto {{ locs_or_addrs_idx }} = sorted_{{ locs_or_addrs_tensor }}[cache_loc_run_id]; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }})); + {%- else %} + if ({{ locs_or_addrs_idx }} != kCacheLocationMissing) { + cache_weights = &lxu_cache_weights[{{ locs_or_addrs_idx }}][0]; + } + {%- endif %} + } + {%- for tensor in args.split_tensors %} + {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; + const auto {{ tensor }}_placement = static_cast({{ tensor }}_placements[t]); + const int64_t {{ tensor }}_offset = {{ tensor }}_offsets[t]; + if ({{ tensor }}_placement == PlacementType::DEVICE) { + {{ tensor }} = &{{ tensor }}_dev[{{ tensor }}_offset]; + } else { + {{ tensor }} = &{{ tensor }}_uvm[{{ tensor }}_offset]; + } + {%- endfor %} + + auto weight_row_template = + WeightRow>( + weights, + cache_weights, + D, + stochastic_rounding, + &stochastic_rounding_philox_args, + threadIdx.x + run_id * blockDim.x); + + float2 qparams_template; + if constexpr (kIsInt8) { + if (!cache_weights) { + qparams_template = weight_row_template.load_qparams(); + } + } + + {%- if not ssd %} + [[maybe_unused]] constexpr auto enable_optimizer_offloading = false; + {%- endif %} + + {{ split_precomputation }} + + {# /* Note: technically, global weight decay (gwd) compensation should be done before + `split_precomputation`). But since decouple mode in `rowwise_adagrad` only computes correction, + the order of applying gwd does not matter. We perform gwd update before `split_weight_update` + below to minimize number of times to load weights. + So, note that the behavior may be different if you want to enable gwd for other optimizers + such as `lamb` or `partial_rowwise_lamb`. + */#} + float2 qparams_new; + {{ + generate_optimized_grad_sum_loop_access( + """ + Vec4TAcc weight_new = weight_row_template.load(d, qparams_template); + Vec4TAcc& grad = {grad_vec}; + {global_weight_decay_update} + {split_weight_update} + if (kIsInt8 && !cache_weights) { + shared_weight_update_row[d_vec] = weight_new; + } else { + // qparams_new not used if type is not int8 + weight_row_template.store(weight_new, d, qparams_new); + } + """, + other_formats={ + "split_weight_update": split_weight_update, + "global_weight_decay_update": "weight_new.mul_(global_weight_decay);" if has_global_weight_decay_support else "" + }, + ) + }} + + if constexpr (kIsInt8) { + if (!cache_weights) { + // Calculate new qparams after row update + qparams_new = thrust_find_qparams>( + shared_weight_update_row, D); + weight_row_template.store_qparams(qparams_new); + + // Fetch cached updated row from shared mem and quantize on-the-fly + // when saving to lowp embedding + for (int32_t vec = 0; + (vec * kThreadGroupSize + threadIdx.x) * VEC_WIDTH < D; + ++vec) { + const auto d_vec = vec * kThreadGroupSize + threadIdx.x; + const int32_t d = d_vec * VEC_WIDTH; + weight_row_template.store( + shared_weight_update_row[d_vec], + d, + qparams_new); + } + } + } + + {{ split_post_update }} +} + // clang-format on From 52791a686d97e56fab3d7a7edb250100686d74d2 Mon Sep 17 00:00:00 2001 From: yadai Date: Tue, 28 Oct 2025 03:50:36 +0000 Subject: [PATCH 4/9] update --- ...ing_backward_split_kernel_warp_template.cu | 21 +++++++++---------- .../embedding_backward_split_template.cu | 4 ++-- ...optimizer_split_device_kernel_template.cuh | 10 ++++++--- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index d1f715511d..6e41109a88 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -471,13 +471,11 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc int64_t s_weights_offset = is_valid? weights_offsets[s_t_0] : 0; int32_t s_weights_placement = is_valid? weights_placements[s_t_0] : 0; - {# {%- for tensor in args.split_tensors %} {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; - const auto s_{{ tensor }}_placement = static_cast({{ tensor }}_placements[t]); - const int64_t s_{{ tensor }}_offset = {{ tensor }}_offsets[t]; + const auto s_{{ tensor }}_placement = {{ tensor }}_placements[s_t_0]; + const int64_t s_{{ tensor }}_offset = {{ tensor }}_offsets[s_t_0]; {%- endfor %} - #} // at::acc_type* __restrict__ s_momentum1; // if (static_cast(s_momentum1_placement) == PlacementType::DEVICE) { @@ -503,13 +501,10 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc {%- endif %} int32_t table_unique_indice_offset = SUBWARP_SHFL_SYNC(s_table_unique_indice_offset, i); - {# {%- for tensor in args.split_tensors %} - {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; - const auto s_{{ tensor }}_placement = static_cast({{ tensor }}_placements[t]); - const int64_t s_{{ tensor }}_offset = {{ tensor }}_offsets[t]; + const auto {{ tensor }}_placement = SUBWARP_SHFL_SYNC(s_{{ tensor }}_placement, i); + const int64_t {{ tensor }}_offset = SUBWARP_SHFL_SYNC(s_{{ tensor }}_offset, i); {%- endfor %} - #} // const int64_t momentum1_offset = SHFL_SYNC(s_momentum1_offset, i); // const auto momentum1_placement = static_cast(SHFL_SYNC(s_momentum1_placement, i)); @@ -582,8 +577,8 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc dev_weights, uvm_weights, lxu_cache_weights, - weights_placements, - weights_offsets, + weights_placement, + weights_offset, sorted_{{ locs_or_addrs_tensor }}, grad_sum, smem_grad_sum, @@ -608,6 +603,10 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc {%- if ssd %} enable_optimizer_offloading, {%- endif %} + {%- for tensor in args.split_tensors %} + {{ tensor }}_placement, + {{ tensor }}_offset, + {%- endfor %} {{ args.split_kernel_arg_names | join(", ") }} ); {%- else %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 0dd9b91937..f316944253 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1391,6 +1391,7 @@ Tensor {{ embedding_cuda_op }}( div_round_up(total_unique_indices, num_warp_per_row_groups), get_max_thread_blocks_()); +{# #ifdef USE_ROCM {%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %} @@ -1428,8 +1429,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} #endif - - +#} FBGEMM_LAUNCH_KERNEL( backward_warp_per_row_kernel, warp_per_row_grid_size, diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh index a18b97133e..75910068c2 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh @@ -239,6 +239,10 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( {%- if ssd %} const bool enable_optimizer_offloading, {%- endif %} + {%- for tensor in args.split_tensors %} + const int32_t {{ tensor }}_placement, + const int64_t {{ tensor }}_offset, + {%- endfor %} {{ args.split_ref_kernel_args | replace_pta_namespace() | join(",\n ") }} ) { constexpr auto kIsInt8 = std::is_same_v; @@ -270,9 +274,9 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( } {%- for tensor in args.split_tensors %} {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; - const auto {{ tensor }}_placement = static_cast({{ tensor }}_placements[t]); - const int64_t {{ tensor }}_offset = {{ tensor }}_offsets[t]; - if ({{ tensor }}_placement == PlacementType::DEVICE) { + // const auto {{ tensor }}_placement = static_cast({{ tensor }}_placements[t]); + // const int64_t {{ tensor }}_offset = {{ tensor }}_offsets[t]; + if (static_cast({{ tensor }}_placement) == PlacementType::DEVICE) { {{ tensor }} = &{{ tensor }}_dev[{{ tensor }}_offset]; } else { {{ tensor }} = &{{ tensor }}_uvm[{{ tensor }}_offset]; From 5856368b0e6ca674ec9221d0217825aa4766db89 Mon Sep 17 00:00:00 2001 From: yadai Date: Tue, 28 Oct 2025 06:17:03 +0000 Subject: [PATCH 5/9] update --- .../embedding_backward_split_kernel_warp_template.cu | 8 ++++---- .../backward/embedding_backward_split_template.cu | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 6e41109a88..cf48896b5d 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -1166,10 +1166,10 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- endmacro %} {%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} - {%- for grad_type in ['float', 'at::Half'] %} - {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} - {%- for cache_type in ['float', 'at::Half'] %} - {%- for index_type in ['int32_t', 'int64_t'] %} + {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} + {%- for emb_type in (['float', 'at::Half', 'at::BFloat16'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} + {%- for cache_type in ['float', 'at::Half', 'at::BFloat16'] %} + {%- for index_type in ['int32_t', 'int64_t', 'at::BFloat16'] %} {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index f316944253..edc3b629e5 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1391,7 +1391,6 @@ Tensor {{ embedding_cuda_op }}( div_round_up(total_unique_indices, num_warp_per_row_groups), get_max_thread_blocks_()); -{# #ifdef USE_ROCM {%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %} @@ -1429,7 +1428,6 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} #endif -#} FBGEMM_LAUNCH_KERNEL( backward_warp_per_row_kernel, warp_per_row_grid_size, From 676414b5ca3488faf5fc7e8011a95242ade9ef82 Mon Sep 17 00:00:00 2001 From: yadai Date: Tue, 28 Oct 2025 08:16:36 +0000 Subject: [PATCH 6/9] update --- ...dding_backward_split_kernel_cta_template.cu | 2 +- .../embedding_backward_split_template.cu | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu index 25f7119a7a..b10eb1312e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu @@ -625,7 +625,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row codegen/embedding_common_code_generator.py for more details */ #} -{{ instantiate_templates(use_subwarp_shuffle=False) }} +{{ instantiate_templates(use_subwarp_shuffle=True) }} //////////////////////////////////////////////////////////////////////////////// #endif diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index edc3b629e5..61b7f47662 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1212,9 +1212,9 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; {%- if is_optimized_hip_kernel_supported_mode %} - if (false) { + if (!kUseVecBlocking) { backward_cta_per_row_kernel = - {{ hip_mixed_d_cta_kernel }} + {{ cta_kernel }} ; + 1, + 32, + false>; } {%- endif %} @@ -1247,7 +1247,7 @@ Tensor {{ embedding_cuda_op }}( FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, - dim3(kThreadGroupSize, num_cta_per_row_groups), + dim3(32, num_cta_per_row_groups), cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), grad_output_accessor, @@ -1350,7 +1350,7 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; {%- if is_optimized_hip_kernel_supported_mode %} - if (true) { + if (!kUseVecBlocking) { printf("%s:%d call here\n", __FILE__, __LINE__); backward_warp_per_row_kernel = {{ hip_mixed_d_warp_kernel }} @@ -1384,8 +1384,8 @@ Tensor {{ embedding_cuda_op }}( used_shared_bytes); } - // auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); - auto blockSize = dim3(32, num_warp_per_row_groups); + auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + // auto blockSize = dim3(32, num_warp_per_row_groups); int32_t warp_per_row_grid_size = std::min( div_round_up(total_unique_indices, num_warp_per_row_groups), From a9a04eddcd4f1d0e3d4da5bfc4f0ea811dc36f7f Mon Sep 17 00:00:00 2001 From: Wulley Date: Tue, 28 Oct 2025 08:59:14 +0000 Subject: [PATCH 7/9] update subwarp kernel --- ...ing_backward_split_kernel_warp_template.cu | 1 + .../embedding_backward_split_template.cu | 52 +++++++++++++------ 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index cf48896b5d..63f0894715 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -37,6 +37,7 @@ not dense and not is_index_select and not is_gwd_kernel and + not nobag and not vbe and not ssd %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 61b7f47662..baada5342c 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -53,7 +53,8 @@ using namespace fbgemm_gpu; not dense and not is_index_select and not is_gwd_kernel and - not vbe and + not vbe and + not nobag and not ssd %} template < @@ -1211,8 +1212,10 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; {%- if is_optimized_hip_kernel_supported_mode %} - if (!kUseVecBlocking) { + auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); + if (max_D <= 128) { backward_cta_per_row_kernel = {{ cta_kernel }} ; + + auto cta_blockSize = dim3(32, num_cta_per_row_groups); } + {%- else %} + auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); {%- endif %} // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1247,7 +1253,7 @@ Tensor {{ embedding_cuda_op }}( FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, - dim3(32, num_cta_per_row_groups), + cta_blockSize, cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), grad_output_accessor, @@ -1349,9 +1355,10 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; {%- if is_optimized_hip_kernel_supported_mode %} - if (!kUseVecBlocking) { - printf("%s:%d call here\n", __FILE__, __LINE__); + auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + if (use_hip_kernel && mixed_D) { backward_warp_per_row_kernel = {{ hip_mixed_d_warp_kernel }} ; + kFixedMaxVecsPerThread, + kThreadGroupSize, + kUseVecBlocking>; + if (max_D <= 128) { + backward_warp_per_row_kernel = + {{ hip_mixed_d_warp_kernel }} + ; + + blockSize = dim3(32, num_warp_per_row_groups); + } } - {%- endif %} - + {%- else %} // Compute shared memory size for warp_per_row - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + + {%- endif %} int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { @@ -1383,10 +1407,6 @@ Tensor {{ embedding_cuda_op }}( backward_warp_per_row_kernel, used_shared_bytes); } - - auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); - // auto blockSize = dim3(32, num_warp_per_row_groups); - int32_t warp_per_row_grid_size = std::min( div_round_up(total_unique_indices, num_warp_per_row_groups), get_max_thread_blocks_()); From 1c6b9b44c76bd693149163a7fef426255258707c Mon Sep 17 00:00:00 2001 From: xzhu Date: Mon, 27 Oct 2025 03:02:34 +0000 Subject: [PATCH 8/9] grad sum kernel unroll improvement --- ..._backward_split_device_kernel_template.cuh | 144 +++++++++++++----- 1 file changed, 106 insertions(+), 38 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index b9db6e47f8..d58f67bcb0 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -14,6 +14,98 @@ using namespace fbgemm_gpu; +// Helper macro: Generate block_size grad_offset_j_i variables (i from 1 to block_size-1) +#define GRAD_OFFSET(i, j) const auto grad_offset_j_##i = SHFL_SYNC(grad_offset, j + i); +#define L(i, j) int32_t l_j_##i = SHFL_SYNC(l, j + i); +#define B(i, j) int32_t b_j_##i = SHFL_SYNC(b, j + i); +#define D_START(i, j) int32_t D_start_j_##i = SHFL_SYNC(D_start, j + i); +#define IDX_WEIGHT(i, j) at::acc_type idx_weight_j_##i = SHFL_SYNC(idx_weight, j + i); + +#define REPEAT_8(X, j) X(1, j); X(2, j); X(3, j); X(4, j); X(5, j); X(6, j); X(7, j); +#define REPEAT_4(X, j) X(1, j); X(2, j); X(3, j); +#define REPEAT_2(X, j) X(1, j); +#define REPEAT_1(X, j) // No additional variables needed for block size 1 + +#define REPEAT_I_S_8(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); X(4, j, m, n); X(5, j, m, n); X(6, j, m, n); X(7, j, m, n); +#define REPEAT_I_S_4(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); +#define REPEAT_I_S_2(X, j, m, n) X(1, j, m, n); +#define REPEAT_I_S_1(X, j, m, n) // No additional variables needed for block size 1 + +// Helper macro: Generate block_size Vec4TAcc objects (i from 1 to block_size-1) +// if nobag and is_index_select +#define GRAD_VEC_N_I(i, grad_offset, grad_stride, d) Vec4TAcc grad_out_vec_##i(&grad_output[grad_offset + l_j_##i * grad_stride + d]); +// elif nobag +#define GRAD_VEC_N(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[l_j_##i][d]); +// elif vbe +#define GRAD_VEC_V(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[0][grad_offset_j_##i + d]); +// else +#define GRAD_VEC(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[b_j_##i][0] + D_start_j_##i + d); + +// Helper macro: Generate block_size fma_ calls (i from 1 to block_size-1) +#define FMA_GRAD(i, vec) grad_sum[vec].fma_(grad_out_vec_##i, idx_weight_j_##i); +// Helper macro: Generate block_size add_ calls (i from 1 to block_size-1) +#define ADD_GRAD(i, vec) grad_sum[vec].add_(grad_out_vec_##i); + +// Core macro: Process blocks of specified size (block_size = 8/4/2/1) +// Parameters: +// - block_size: Size of each block to process +// - unroll_count: Number of unroll iterations for the inner loop +#define PROCESS_BLOCK(block_size, unroll_count, grad_sum, grad_output, grad_offset, vec_start, kThreadGroupSize, threadIdx_x, VEC_WIDTH, D, j, sl, sl_end) \ + for (; j + (block_size - 1) < kThreadGroupSize && sl + j + (block_size - 1) < sl_end; j += block_size) { \ + {%- if nobag %} + int32_t l_j_0 = SHFL_SYNC(l, j); \ + REPEAT_##block_size(L, j) \ + {%- elif vbe %} + /* Generate block_size grad_offset_j_0 ~ grad_offset_j_(block_size-1) */ \ + const auto grad_offset_j_0 = SHFL_SYNC(grad_offset, j); \ + /* Generate subsequent grad_offset_j_1 ~ grad_offset_j_(block_size-1) based on block size */ \ + REPEAT_##block_size(GRAD_OFFSET, j) \ + {%- else %} + int32_t b_j_0 = SHFL_SYNC(b, j); \ + REPEAT_##block_size(B, j) \ + int32_t D_start_j_0 = SHFL_SYNC(D_start, j); \ + REPEAT_##block_size(D_START, j) \ + {%- endif %} + {%- if weighted %} + at::acc_type idx_weight_j_0 = SHFL_SYNC(idx_weight, j); \ + REPEAT_##block_size(IDX_WEIGHT, j) \ + {%- endif %} + {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %} + \ + for (int32_t vec = 0; vec < unroll_count && (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH) < D; ++vec) { \ + const int32_t d = (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH); \ + /* Generate block_size Vec4TAcc objects and accumulate them */ \ + Vec4TAcc grad_out_vec_0( \ + {%- if nobag and is_index_select %} + &grad_output[grad_offset + l_j_0 * grad_stride + d] \ + {%- elif nobag %} + &grad_output[l_j_0][d] \ + {%- elif vbe %} + &grad_output[0][grad_offset_j_0 + d] \ + {%- else %} + &grad_output[b_j_0][0] + D_start_j_0 + d \ + {%- endif %} + ); \ + {%- if nobag and is_index_select %} + REPEAT_I_S_##block_size(GRAD_VEC_N_I, grad_offset, grad_stride, d) \ + {%- elif nobag %} + REPEAT_##block_size(GRAD_VEC_N, d) \ + {%- elif vbe %} + REPEAT_##block_size(GRAD_VEC_V, d) \ + {%- else %} + REPEAT_##block_size(GRAD_VEC, d) \ + {%- endif %} + \ + {%- if weighted %} + grad_sum[vec].fma_(grad_out_vec_0, idx_weight_j_0); \ + REPEAT_##block_size(FMA_GRAD, vec) \ + {%- else %} + grad_sum[vec].add_(grad_out_vec_0); \ + REPEAT_##block_size(ADD_GRAD, vec) \ + {%- endif %} + } \ + } + {%- if gen_once %} {#- /* The kernels in this section will be generated only once for all TBE configs @@ -141,45 +233,21 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( ? sorted_indice_weights[segment_start + sl_j] : 0.0; {%- endif %} - for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; ++j) { - {%- if nobag %} - int32_t l_j = SHFL_SYNC(l, j); - {%- elif vbe %} - const auto grad_offset_j = SHFL_SYNC(grad_offset, j); - {%- else %} - int32_t b_j = SHFL_SYNC(b, j); - int32_t D_start_j = SHFL_SYNC(D_start, j); - {%- endif %} - - {%- if weighted %} - at::acc_type idx_weight_j = SHFL_SYNC(idx_weight, j); - {%- endif %} + int32_t j = 0; - {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %} - - #pragma unroll kFixedMaxVecsPerThread - for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) { - const int32_t d = {{ d }}; - Vec4TAcc grad_out_vec( - {%- if nobag and is_index_select %} - // grad_output is 1d - &grad_output[grad_offset + l_j * grad_stride + d] - {%- elif nobag %} - &grad_output[l_j][d] - {%- elif vbe %} - &grad_output[0][grad_offset_j + d] - {%- else %} - &grad_output[b_j][0] + D_start_j + d - {%- endif %} // if nobag - ); - - {%- if weighted %} - grad_sum[vec].fma_(grad_out_vec, idx_weight_j); - {%- else %} - grad_sum[vec].add_(grad_out_vec); - {%- endif %} - } - } + // Process blocks of different sizes with loop unrolling + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(4, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(2, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) } {%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %} From 03eeab1560275ae0f42838db45a15c1dead656c2 Mon Sep 17 00:00:00 2001 From: xzhu Date: Tue, 28 Oct 2025 13:53:58 +0000 Subject: [PATCH 9/9] split indice weights unroll improvement --- ..._backward_split_indice_weights_template.cu | 113 +++++++++++++++++- 1 file changed, 110 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 6d38d1d99a..54df8ab821 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -23,6 +23,10 @@ #include "fbgemm_gpu/utils/assert_macros.h" #include "fbgemm_gpu/utils/kernel_launcher.cuh" +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -47,6 +51,87 @@ using namespace fbgemm_gpu; -}} }() +// Macro to process weights loop, with cache usage controlled by 'use_cache' (0 = no cache, 1 = use cache) +#define PROCESS_WEIGHTS_LOOP(use_cache, unroll_count) \ + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { \ + /* Get offset indices (common logic) */ \ + const auto offset_idx_j0 = shfl_sync(offset_idx, j); \ + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); \ + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); \ + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); \ + \ + /* Get cache indices only if use_cache is 1 (using compile-time condition) */ \ + const auto cache_idx_j0 = (use_cache) ? shfl_sync(cache_idx, j) : 0; \ + const auto cache_idx_j1 = (use_cache) ? shfl_sync(cache_idx, j+1) : 0; \ + const auto cache_idx_j2 = (use_cache) ? shfl_sync(cache_idx, j+2) : 0; \ + const auto cache_idx_j3 = (use_cache) ? shfl_sync(cache_idx, j+3) : 0; \ + \ + /* Gradient weight variables (common) */ \ + at::acc_type grad_indice_weight0 = 0.0; \ + at::acc_type grad_indice_weight1 = 0.0; \ + at::acc_type grad_indice_weight2 = 0.0; \ + at::acc_type grad_indice_weight3 = 0.0; \ + \ + /* Weight row accessors (common) */ \ + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); \ + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); \ + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); \ + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); \ + \ + /* Loop over vectors to compute gradients */ \ + for (int32_t vec = 0; vec < unroll_count && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { \ + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; \ + \ + Vec4T> weight0, weight1, weight2, weight3; \ + \ + /* Load weights: choose logic based on use_cache (compile-time condition) */ \ + if constexpr (use_cache) { \ + /* Cache-aware loading (second code snippet logic) */ \ + weight0 = (cache_idx_j0 != kCacheLocationMissing) ? \ + Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : \ + weight_row0.load(d); \ + weight1 = (cache_idx_j1 != kCacheLocationMissing) ? \ + Vec4T>(&lxu_cache_weights[cache_idx_j1][d]) : \ + weight_row1.load(d); \ + weight2 = (cache_idx_j2 != kCacheLocationMissing) ? \ + Vec4T>(&lxu_cache_weights[cache_idx_j2][d]) : \ + weight_row2.load(d); \ + weight3 = (cache_idx_j3 != kCacheLocationMissing) ? \ + Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : \ + weight_row3.load(d); \ + } else { \ + /* Direct weight loading (first code snippet logic) */ \ + weight0 = weight_row0.load(d); \ + weight1 = weight_row1.load(d); \ + weight2 = weight_row2.load(d); \ + weight3 = weight_row3.load(d); \ + } \ + \ + /* Gradient calculation (common) */ \ + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + \ + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; \ + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + \ + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; \ + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + \ + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; \ + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + \ + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; \ + } \ + \ + /* Warp reduction and result assignment (common) */ \ + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); \ + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); \ + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); \ + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); \ + \ + if (threadIdx.x == 0) { \ + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; \ + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; \ + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; \ + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; \ + } \ + } + {%- for vbe in ([True, False]) %} {%- set vdesc = "_vbe" if vbe else "" %} @@ -98,8 +183,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void {%- endif %} ) { constexpr int32_t kVecWidth = 4; - [[maybe_unused]] int error_code = 0; - [[maybe_unused]] int64_t error_value = 0; + int error_code = 0; + int64_t error_value = 0; int32_t T = D_offsets.size(0) - 1; auto b_t = blockIdx.x * blockDim.y + threadIdx.y; @@ -210,7 +295,20 @@ __global__ __launch_bounds__(kForwardMaxThreads) void ) {%- endif %} - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { + int32_t j = 0; + {%- if not ssd and not dense and not use_vec_blocking and not vbe %} + // Currently for split_embedding_codegen_grad_indice_weights_kernel only + if (placement != PlacementType::MANAGED_CACHING) { + // no cache logic + #pragma unroll kFixedMaxVecsPerThread + PROCESS_WEIGHTS_LOOP(0, kFixedMaxVecsPerThread) + } else { + // with cache logic + #pragma unroll kFixedMaxVecsPerThread + PROCESS_WEIGHTS_LOOP(1, kFixedMaxVecsPerThread) + } + {%- endif %} + for (; j < kWarpSize && l_start + j < L; ++j) { const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); @@ -359,6 +457,15 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output); CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0);