From 5200c54221c5a29816f735e053829d812a5351fa Mon Sep 17 00:00:00 2001 From: Erik Ordentlich Date: Tue, 4 Nov 2025 14:51:54 -0800 Subject: [PATCH] clear up confusing error message for non-contiguous labels in rf classifier and add docstring note about this requirement Signed-off-by: Erik Ordentlich --- python/src/spark_rapids_ml/classification.py | 7 +++++++ python/src/spark_rapids_ml/tree.py | 14 +++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/python/src/spark_rapids_ml/classification.py b/python/src/spark_rapids_ml/classification.py index 46b5dae4..f78f8527 100644 --- a/python/src/spark_rapids_ml/classification.py +++ b/python/src/spark_rapids_ml/classification.py @@ -418,6 +418,13 @@ class RandomForestClassifier( max_batch_size: int (default = 4096) Maximum number of nodes that can be processed in a given batch. + Notes + ----- + The label column is required to be an integer in the range ``0, 1, ..., num_classes - 1``. Moreover, for fit() to succeed, + all values in this range are required to be present in the input data and also each worker must receive the full range of values. + If this is not the case, an error will be raised with possible workarounds being to remap the labels to the expected range, + increase the number of very rare label occurrences in the input data, rerun with fewer workers, or shuffle the input data. + Examples -------- >>> import numpy diff --git a/python/src/spark_rapids_ml/tree.py b/python/src/spark_rapids_ml/tree.py index f6674942..10d66949 100644 --- a/python/src/spark_rapids_ml/tree.py +++ b/python/src/spark_rapids_ml/tree.py @@ -413,11 +413,13 @@ def _single_fit(rf: cuRf) -> Dict[str, Any]: # Fit a random forest model on the dataset (X, y) rf.fit(X, y, convert_dtype=False) + missing_labels_error_message = "A GPU worker did not receive all label values in the range 0, 1, ..., num_classes - 1, which is currently required. 
\ + Depending on the root cause, possible work arounds are to remap the labels to the required range, increase the number \ + of very rare label occurrences in the input data, rerun with fewer workers, or shuffle the input data." + if is_classification: if rf.classes_.max() != rf.n_classes_ - 1: - raise RuntimeError( - "A GPU worker did not receive all label values. Rerun with fewer workers or shuffle input data." - ) + raise RuntimeError(missing_labels_error_message) # serialized_model is Dictionary type serialized_model = rf._treelite_model_bytes @@ -451,11 +453,9 @@ def _single_fit(rf: cuRf) -> Dict[str, Any]: exc_str = traceback.format_exc() if "different num_class than the first model object" in exc_str: - raise RuntimeError( - "Some GPU workers did not receive all label values. Rerun with fewer workers or shuffle input data." - ) + raise RuntimeError(missing_labels_error_message) else: - raise err + raise final_model_bytes = pickle.dumps(_treelite_model_bytes) final_model = base64.b64encode(final_model_bytes).decode("utf-8")