From d15df541f05efd2e707dac1674cefaec101552c6 Mon Sep 17 00:00:00 2001 From: Christiaan Date: Thu, 15 Feb 2018 13:10:04 +0200 Subject: [PATCH 1/3] Update xgboost_explainer.py I have found some cases where the lines in lst_node_str do not match the max node idx in the tree, and then the code breaks at node_lst[node_idx] = d --- xgboost_explainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xgboost_explainer.py b/xgboost_explainer.py index 9625beb..7d1ecb7 100644 --- a/xgboost_explainer.py +++ b/xgboost_explainer.py @@ -31,7 +31,10 @@ def model2table(bst, eta=0.3, lmda=1.0): parent = {} parent[0] = None lst_node_str = line.split('\n') - node_lst = [{} for _ in range(len(lst_node_str)-1)] +#---------I have found some cases where the lines in lst_node_str do not match the max node idx----- + max_node_idx = max([int(node[:node.index(":")]) for node in lst_node_str if len(node) > 0]) + node_lst = [{} for _ in range(max_node_idx+1)] +#--------------------------------------------------------------------------------------------------- for node in lst_node_str: node = node.strip() # print("fdfdf",len(node)) From c0b9aee4ecac1a1c3718d82adf7c4f7b5fe1d1a8 Mon Sep 17 00:00:00 2001 From: Christiaan Date: Thu, 15 Feb 2018 15:02:29 +0200 Subject: [PATCH 2/3] Update xgboost_explainer.py with my specific example I still had index errors at two places...corrected these --- xgboost_explainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xgboost_explainer.py b/xgboost_explainer.py index 7d1ecb7..eed679e 100644 --- a/xgboost_explainer.py +++ b/xgboost_explainer.py @@ -82,6 +82,7 @@ def model2table(bst, eta=0.3, lmda=1.0): check_params(t, eta, lmda) for j in reversed(range(len(t))): node = t[j] + if not bool(node): continue if node['is_leaf']: G = -1.*node['leaf']*(node['cover']+lmda)/eta else: @@ -91,6 +92,7 @@ def model2table(bst, eta=0.3, lmda=1.0): for t in tree_lst: for j in reversed(range(len(t))): node = t[j] + if not bool(node): continue if node['parent'] is None: node['logit_delta'] = node['logit'] - .0 else: From acdc0aa8a5ec611bdf84ebf607133e3d65bf2d87 Mon Sep 17 00:00:00 2001 From: Christiaan Date: Thu, 15 Feb 2018 15:21:38 +0200 Subject: [PATCH 3/3] Update xgboost_explainer.py --- xgboost_explainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xgboost_explainer.py b/xgboost_explainer.py index eed679e..18db431 100644 --- a/xgboost_explainer.py +++ b/xgboost_explainer.py @@ -76,7 +76,8 @@ def model2table(bst, eta=0.3, lmda=1.0): # node_lst.append(d) node_lst[node_idx] = d for j, node in enumerate(node_lst): - node_lst[j]['parent'] = parent[node_lst[j]['node']] + if bool(node): + node_lst[j]['parent'] = parent[node_lst[j]['node']] tree_lst[i] = node_lst for t in tree_lst: check_params(t, eta, lmda)