diff --git a/xgboost_explainer.py b/xgboost_explainer.py index 9625beb..18db431 100644 --- a/xgboost_explainer.py +++ b/xgboost_explainer.py @@ -31,7 +31,10 @@ def model2table(bst, eta=0.3, lmda=1.0): parent = {} parent[0] = None lst_node_str = line.split('\n') - node_lst = [{} for _ in range(len(lst_node_str)-1)] +#---------I have found some cases where the lines in lst_node_str do not match the max node idx----- + max_node_idx = max([int(node[:node.index(":")]) for node in lst_node_str if len(node) > 0]) + node_lst = [{} for _ in range(max_node_idx+1)] +#--------------------------------------------------------------------------------------------------- for node in lst_node_str: node = node.strip() # print("fdfdf",len(node)) @@ -73,12 +76,14 @@ def model2table(bst, eta=0.3, lmda=1.0): # node_lst.append(d) node_lst[node_idx] = d for j, node in enumerate(node_lst): - node_lst[j]['parent'] = parent[node_lst[j]['node']] + if bool(node): + node_lst[j]['parent'] = parent[node_lst[j]['node']] tree_lst[i] = node_lst for t in tree_lst: check_params(t, eta, lmda) for j in reversed(range(len(t))): node = t[j] + if not bool(node): continue if node['is_leaf']: G = -1.*node['leaf']*(node['cover']+lmda)/eta else: @@ -88,6 +93,7 @@ def model2table(bst, eta=0.3, lmda=1.0): for t in tree_lst: for j in reversed(range(len(t))): node = t[j] + if not bool(node): continue if node['parent'] is None: node['logit_delta'] = node['logit'] - .0 else: