Skip to content

Commit b194397

Browse files
committed
Fix: Handle bracket-wrapped parameter values in XGBoost config
- Fixed extract_model_param() to handle values like '[9.9E-2]' - Added comprehensive tests for both bracket and non-bracket formats - Bumped version to 0.2.6.1 (maintenance release per PEP 440) - Fixes #5
1 parent 5d0f82f commit b194397

4 files changed

Lines changed: 150 additions & 4 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "xbooster"
3-
version = "0.2.6.post1"
3+
version = "0.2.6.1"
44
description = "Explainable Boosted Scoring"
55
authors = [
66
{name = "xRiskLab", email = "contact@xrisklab.ai"}
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""
2+
Test module for extract_model_param with different value formats.
3+
4+
This test ensures that the extract_model_param method can handle different
5+
string formats returned by XGBoost config, including values with brackets.
6+
7+
GitHub Issue: Value of base_score parameter is read as a string wrapped in
8+
list brackets (e.g., '[9.9E-2]'), which caused a ValueError when attempting
9+
to convert it directly to a float.
10+
"""
11+
12+
import json
13+
from unittest.mock import patch
14+
15+
import pandas as pd
16+
import pytest
17+
import xgboost as xgb
18+
19+
from xbooster.xgb_constructor import XGBScorecardConstructor
20+
21+
22+
@pytest.fixture(scope="module")
23+
def simple_model():
24+
"""
25+
Creates and trains a simple XGBoost model.
26+
27+
Returns:
28+
model (xgb.XGBClassifier): Trained XGBoost model.
29+
"""
30+
X = pd.DataFrame({"feature1": [1, 2, 3, 4, 5], "feature2": [5, 4, 3, 2, 1]})
31+
y = pd.Series([0, 1, 0, 1, 0])
32+
model = xgb.XGBClassifier(n_estimators=5, max_depth=1, random_state=42)
33+
model.fit(X, y)
34+
return model
35+
36+
37+
@pytest.fixture(scope="module")
38+
def simple_data():
39+
"""
40+
Creates simple training data.
41+
42+
Returns:
43+
tuple: X and y DataFrames.
44+
"""
45+
X = pd.DataFrame({"feature1": [1, 2, 3, 4, 5], "feature2": [5, 4, 3, 2, 1]})
46+
y = pd.Series([0, 1, 0, 1, 0])
47+
return X, y
48+
49+
50+
def test_extract_model_param_with_brackets(simple_model, simple_data):
51+
"""
52+
Test that extract_model_param handles values with brackets correctly.
53+
54+
This test verifies the fix for GitHub issue where base_score came as '[9.9E-2]'
55+
instead of '9.9E-2', causing a ValueError.
56+
"""
57+
X, y = simple_data
58+
constructor = XGBScorecardConstructor(simple_model, X, y)
59+
60+
# Test with bracket format
61+
mock_config = {
62+
"learner": {
63+
"learner_model_param": {"base_score": "[9.9E-2]"},
64+
"gradient_booster": {
65+
"tree_train_param": {"learning_rate": "[0.3]", "max_depth": "[1]"}
66+
},
67+
}
68+
}
69+
70+
with patch.object(constructor.booster_, "save_config", return_value=json.dumps(mock_config)):
71+
base_score = constructor.extract_model_param("base_score")
72+
learning_rate = constructor.extract_model_param("learning_rate")
73+
max_depth = constructor.extract_model_param("max_depth")
74+
75+
assert abs(base_score - 0.099) < 1e-6
76+
assert abs(learning_rate - 0.3) < 1e-6
77+
assert abs(max_depth - 1.0) < 1e-6
78+
79+
80+
def test_extract_model_param_without_brackets(simple_model, simple_data):
81+
"""
82+
Test that extract_model_param handles values without brackets correctly.
83+
84+
This ensures backward compatibility with the standard format.
85+
"""
86+
X, y = simple_data
87+
constructor = XGBScorecardConstructor(simple_model, X, y)
88+
89+
# Test without bracket format (standard)
90+
mock_config = {
91+
"learner": {
92+
"learner_model_param": {"base_score": "9.9E-2"},
93+
"gradient_booster": {"tree_train_param": {"learning_rate": "0.3", "max_depth": "1"}},
94+
}
95+
}
96+
97+
with patch.object(constructor.booster_, "save_config", return_value=json.dumps(mock_config)):
98+
base_score = constructor.extract_model_param("base_score")
99+
learning_rate = constructor.extract_model_param("learning_rate")
100+
max_depth = constructor.extract_model_param("max_depth")
101+
102+
assert abs(base_score - 0.099) < 1e-6
103+
assert abs(learning_rate - 0.3) < 1e-6
104+
assert abs(max_depth - 1.0) < 1e-6
105+
106+
107+
def test_extract_model_param_various_formats(simple_model, simple_data):
108+
"""
109+
Test that extract_model_param handles various numeric string formats.
110+
"""
111+
X, y = simple_data
112+
constructor = XGBScorecardConstructor(simple_model, X, y)
113+
114+
test_cases = [
115+
("5E-1", 0.5),
116+
("[5E-1]", 0.5),
117+
("1.0", 1.0),
118+
("[1.0]", 1.0),
119+
("0.099", 0.099),
120+
("[0.099]", 0.099),
121+
]
122+
123+
for value_str, expected in test_cases:
124+
mock_config = {
125+
"learner": {
126+
"learner_model_param": {"base_score": value_str},
127+
"gradient_booster": {"tree_train_param": {"learning_rate": "0.3"}},
128+
}
129+
}
130+
131+
with patch.object(
132+
constructor.booster_, "save_config", return_value=json.dumps(mock_config)
133+
):
134+
result = constructor.extract_model_param("base_score")
135+
136+
assert abs(result - expected) < 1e-6, (
137+
f"Failed for {value_str}: got {result}, expected {expected}"
138+
)

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

xbooster/xgb_constructor.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,16 @@ def extract_model_param(self, param):
109109
"""
110110
config = json.loads(self.booster_.save_config())
111111
if param == "base_score":
112-
return float(config["learner"]["learner_model_param"][param])
113-
return float(config["learner"]["gradient_booster"]["tree_train_param"][param])
112+
value = config["learner"]["learner_model_param"][param]
113+
else:
114+
value = config["learner"]["gradient_booster"]["tree_train_param"][param]
115+
116+
# Handle different value formats from XGBoost config
117+
if isinstance(value, str):
118+
# Remove brackets if present (e.g., '[9.9E-2]' -> '9.9E-2')
119+
value = value.strip("[]")
120+
121+
return float(value)
114122

115123
def add_detailed_split(self, dataframe: Optional[pd.DataFrame] = None) -> pd.DataFrame:
116124
"""

0 commit comments

Comments
 (0)