This repository was archived by the owner on Jul 16, 2021. It is now read-only.

Provide informative error message on bad type input to predict#35

Open
mrocklin wants to merge 1 commit into dask:master from mrocklin:check-types

Conversation

@mrocklin
Member

No description provided.

@mrocklin
Member Author

mrocklin commented Feb 21, 2019 via email

@TomAugspurger
Member

I suspect the test_sparse failure is similar to what we ran into with dask. IIRC sparse changed to be stricter about not converting to dense implicitly.

No idea about the other failures, unfortunately :/ Possibly something with pytest-xdist?

FWIW, I have a local (unpublished) branch called test-fixup with this diff:

diff --git a/.circleci/config.yml b/.circleci/config.yml
index f1463079..72faf516 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -16,7 +16,7 @@ jobs:
             conda config --add channels conda-forge
             conda create -q -n test-environment python=${PYTHON}
             source activate test-environment
-            conda install -q coverage flake8 pytest pytest-cov pytest-xdist numpy pandas xgboost dask distributed scikit-learn sparse scipy
+            conda install -q coverage flake8 pytest pytest-cov numpy pandas xgboost dask distributed scikit-learn sparse scipy
             pip install -e .
             conda list test-environment
       - run:
diff --git a/dask_xgboost/core.py b/dask_xgboost/core.py
index 6bf29d78..c843a000 100644
--- a/dask_xgboost/core.py
+++ b/dask_xgboost/core.py
@@ -34,7 +34,7 @@ def parse_host_port(address):
     return host, port
 
 
-def start_tracker(host, n_workers):
+def start_tracker(host, n_workers, dask_scheduler=None):
     """ Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
     rabit = RabitTracker(hostIP=host, nslave=n_workers)
@@ -45,6 +45,7 @@ def start_tracker(host, n_workers):
     thread = Thread(target=rabit.join)
     thread.daemon = True
     thread.start()
+    dask_scheduler.xgboost_thread = thread
     return env
 
 
@@ -155,6 +156,13 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
     num_class = params.get("num_class")
     if num_class:
         result.set_attr(num_class=str(num_class))
+
+    def wait_on_tracker_thread(dask_scheduler):
+        dask_scheduler.xgboost_thread.join()
+        del dask_scheduler.xgboost_thread
+
+    yield client.run_on_scheduler(wait_on_tracker_thread)
+
     raise gen.Return(result)
 
 
diff --git a/setup.cfg b/setup.cfg
index 2348f495..11894603 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -5,4 +5,4 @@ universal=1
 exclude = tests/data,docs,benchmarks,scripts
 
 [tool:pytest]
-addopts = -rsx -v -n 1 --boxed
+addopts = -rsx -v

Looking further, that looks like #29 (comment)
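
For readers skimming the diff above, the core idea appears to be: keep a handle to the background tracker thread on a long-lived object, then join that thread before training returns so it does not outlive the test. Below is a minimal sketch of that pattern using only the standard library. It does not use dask or xgboost; `SimpleNamespace` is a stand-in for the scheduler object, and `fake_tracker_work` and the `finished` attribute are hypothetical names introduced only for illustration.

```python
import threading
import time
from types import SimpleNamespace


def start_tracker(scheduler):
    """Start a daemon thread and stash it on the scheduler-like object."""

    def fake_tracker_work():
        # Hypothetical placeholder for the real tracker's blocking join.
        time.sleep(0.05)
        scheduler.finished = True

    thread = threading.Thread(target=fake_tracker_work)
    thread.daemon = True
    thread.start()
    # Keep a reference so a later call can wait for the thread to exit.
    scheduler.xgboost_thread = thread


def wait_on_tracker_thread(scheduler):
    """Block until the stashed thread exits, then drop the reference."""
    scheduler.xgboost_thread.join()
    del scheduler.xgboost_thread


# Stand-in for the scheduler object (an assumption, not the dask class).
scheduler = SimpleNamespace(finished=False)
start_tracker(scheduler)
wait_on_tracker_thread(scheduler)
```

After `wait_on_tracker_thread` returns, the thread has exited and the attribute has been removed, so a later run starts from a clean scheduler object.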
