Skip to content
This repository was archived by the owner on Jul 16, 2021. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions dask_xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
sparse = False
ss = False

import dask
from dask import delayed
from dask.distributed import wait, default_client
import dask.dataframe as dd
Expand All @@ -26,6 +27,13 @@
logger = logging.getLogger(__name__)


def maybe_get_client():
    """Return the default distributed client, or ``None`` when absent.

    ``default_client()`` raises ``ValueError`` when no client has been
    created; this helper converts that into a ``None`` return so callers
    can fall back to local (non-distributed) execution.
    """
    try:
        client = default_client()
    except ValueError:
        client = None
    return client


def parse_host_port(address):
if '://' in address:
address = address.rsplit('://', 1)[1]
Expand Down Expand Up @@ -187,6 +195,12 @@ def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
--------
predict
"""
if (client is None or
not dask.is_dask_collection(data) or
not dask.is_dask_collection(labels)):
dtrain = xgb.DMatrix(data, labels, **dmatrix_kwargs)
return xgb.train(params, dtrain, **kwargs)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also raise here and point the user to xgb.train with an informative message.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still debugging this, but it seems like in a Randomized/GridSearchCV context we can end up with a concrete ndarray here.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think https://github.com/dask/dask-xgboost/pull/28/files#r214392082 left out the important bit of information that GridSearchCV can pass a concrete ndarray here, even if you start with a dask array.


return client.sync(_train, client, params, data,
labels, dmatrix_kwargs, **kwargs)

Expand Down Expand Up @@ -233,7 +247,10 @@ def predict(client, model, data):
--------
train
"""
if isinstance(data, dd._Frame):
if client is None or isinstance(data, (np.ndarray, pd.DataFrame)):
dm = xgb.DMatrix(data)
result = model.predict(dm)
elif isinstance(data, dd._Frame):
result = data.map_partitions(_predict_part, model=model)
result = result.values
elif isinstance(data, da.Array):
Expand Down Expand Up @@ -276,14 +293,14 @@ def fit(self, X, y=None):
``eval_metric``, ``early_stopping_rounds`` and ``verbose`` fit
kwargs.
"""
client = default_client()
client = maybe_get_client()
xgb_options = self.get_xgb_params()
self._Booster = train(client, xgb_options, X, y,
num_boost_round=self.n_estimators)
return self

def predict(self, X):
    """Predict with the fitted booster.

    Parameters
    ----------
    X : array-like or dask collection
        Feature matrix to predict on.

    Returns
    -------
    Predictions from the trained ``_Booster``; distributed when a dask
    client is available, computed locally otherwise.
    """
    # Use maybe_get_client() only: calling default_client() directly
    # raises ValueError when no client exists, which would defeat the
    # local (non-distributed) fallback in predict().
    client = maybe_get_client()
    return predict(client, self._Booster, X)


Expand Down Expand Up @@ -316,14 +333,16 @@ def fit(self, X, y=None, classes=None):
2. The labels are not automatically label-encoded
3. The ``classes_`` and ``n_classes_`` attributes are not learned
"""
client = default_client()
client = maybe_get_client()

if classes is None:
if isinstance(y, da.Array):
if isinstance(y, np.ndarray):
classes = np.unique(classes)
elif isinstance(y, da.Array):
classes = da.unique(y)
else:
classes = y.unique()
classes = classes.compute()
classes = dask.compute(classes)
else:
classes = np.asarray(classes)
self.classes_ = classes
Expand All @@ -346,16 +365,20 @@ def fit(self, X, y=None, classes=None):
# TODO: auto label-encode y
# that will require a dependency on dask-ml
# TODO: sample weight

self._Booster = train(client, xgb_options, X, y,
num_boost_round=self.n_estimators)
bst = train(client, xgb_options, X, y,
num_boost_round=self.n_estimators)
self._Booster = bst
return self

def predict(self, X):
    """Predict class labels for ``X``.

    Parameters
    ----------
    X : array-like or dask collection
        Feature matrix to predict on.

    Returns
    -------
    Class indices: ``argmax`` over class probabilities when the booster
    emits a 2-d probability array, otherwise a 0/1 threshold on the
    margin.  The result type mirrors the input (NumPy for concrete
    input, dask array otherwise).
    """
    # maybe_get_client() returns None when no scheduler is running,
    # letting predict() fall back to local computation.
    client = maybe_get_client()

    class_probs = predict(client, self._Booster, X)
    if class_probs.ndim > 1:
        # Concrete results need np.argmax; lazy dask results need
        # da.argmax so the graph stays lazy.
        if isinstance(class_probs, (pd.DataFrame, np.ndarray)):
            cidx = np.argmax(class_probs, axis=1)
        else:
            cidx = da.argmax(class_probs, axis=1)
    else:
        # Binary case: probabilities/margins are 1-d; threshold at 0.
        cidx = (class_probs > 0).astype(np.int64)
    return cidx
Expand Down
22 changes: 22 additions & 0 deletions dask_xgboost/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,25 @@ def f(part):
yield dxgb.train(c, param, df, df.x)

assert 'foo' in str(info.value)


@gen_cluster(client=True, timeout=None, check_new_threads=False)
def test_concrete(c, s, a, b):
    # Concrete (non-dask) inputs should work end-to-end for both
    # estimator flavors and produce plain NumPy predictions.
    estimators = [dxgb.XGBClassifier(), dxgb.XGBRegressor()]
    for estimator in estimators:
        estimator.fit(X, y)
        predictions = estimator.predict(X)
        assert isinstance(predictions, np.ndarray)
        estimator.score(X, y)


def test_dask_search_cv(loop):  # noqa
    # Skip before paying for cluster/client startup when dask-ml is
    # unavailable (importorskip raises pytest.skip immediately).
    model_selection = pytest.importorskip('dask_ml.model_selection')
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop):
            est = dxgb.XGBClassifier()
            cv = model_selection.RandomizedSearchCV(est,
                                                    {'max_depth': [1, 10]})
            dX = da.from_array(X, 5)
            dy = da.from_array(y, 5)
            cv.fit(dX, dy)